[英]How to aggregate the values of a DataFrame (and output a numpy array quickly)?
在熊猫中给出以下DataFrame:
user item rating
1 3 2
1 4 5
2 1 5
3 5 1
3 1 3
4 4 4
4 1 1
....
我想将其传输到numpy数组,其中用户列为y轴,项目列为x轴,如下所示:
1 2 3 4 5
1 nan nan 2 5 nan
2 5 nan nan nan nan
3 3 nan nan nan 1
4 1 nan nan 4 nan
如何使用apply
功能快速完成?
您需要一个数据透视表:
>>> df.pivot_table(index='user', columns='item', values='rating')
1 3 4 5
user
1 NaN 2 5 NaN
2 5 NaN NaN NaN
3 3 NaN NaN 1
4 1 NaN 4 NaN
请注意,总共存在NaN
列; 您可以根据需要重新索引以包括它们:
>>> df.pivot_table(index='user', columns='item', values='rating')
.reindex_axis([1, 2, 3, 4, 5], axis=1)
item 1 2 3 4 5
user
1 NaN NaN 2 5 NaN
2 5 NaN NaN NaN NaN
3 3 NaN NaN NaN 1
4 1 NaN NaN 4 NaN
要将这些值放入NumPy数组中,请访问.values
属性:
_.values # _ is the last returned value in the repr
您可以使用pivot
:
print df.pivot(index='user', columns='item', values='rating')
item 1 3 4 5
user
1 NaN 2 5 NaN
2 5 NaN NaN NaN
3 3 NaN NaN 1
4 1 NaN 4 NaN
然后,您需要添加缺少的列-查找min
和max
,在reindex_axis
创建参数标签的reindex_axis
:
print df['item'].min()
1
print df['item'].max()
5
rng = range(df['item'].min(), df['item'].max() + 1)
print rng
[1, 2, 3, 4, 5]
print df.pivot(index='user',columns='item',values='rating').reindex_axis(labels=rng, axis=1)
item 1 2 3 4 5
user
1 NaN NaN 2 5 NaN
2 5 NaN NaN NaN NaN
3 3 NaN NaN NaN 1
4 1 NaN NaN 4 NaN
上次使用的values
用于生成numpy array
:
print df.pivot(index='user', columns='item', values='rating')
.reindex_axis(labels=rng, axis=1)
.values
[[ nan nan 2. 5. nan]
[ 5. nan nan nan nan]
[ 3. nan nan nan 1.]
[ 1. nan nan 4. nan]]
要快速完成,请使用numpy工具进行:
def pivotarray(df):
users,i= np.unique(df['user'],return_inverse=True)
item,j= np.unique(df['item'],return_inverse=True)
a=zeros((len(users),len(item)),int)
a[i,j]=df['rating']
return a
然后(如果需要,您可以先用NaN填充):
In [464]: pivotarray(df)
Out[464]:
array([[0, 2, 5, 0],
[5, 0, 0, 0],
[3, 0, 0, 1],
[1, 0, 4, 0]])
第2列不存在,因为没有第2项。
收益意义重大:
In [465]: %timeit pivotarray(df)
1000 loops, best of 3: 417 µs per loop
In [466]: %timeit df.pivot(index='user', columns='item', values='rating')
100 loops, best of 3: 6.38 ms per loop
In [467]: %timeit df.pivot_table(index='user', columns='item', values='rating')
100 loops, best of 3: 18.6 ms per loop
编辑
包括丢失的物品,可能是黑客:
def pivotarraywithallitems(df):
users,i= np.unique(df['user'],return_inverse=True)
item,j= np.unique(df['item'],return_inverse=True)
miss= (~in1d(arange(1,6),item)).cumsum()
j+=miss[j]
a=zeros((len(users),len(item)+miss[-1]),float)*NaN
a[i,j]=df['rating']
return a
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.