繁体   English   中英

如何根据列拆分 numpy 数组?

[英]How to split a numpy array based on a column?

我有一个表单数组:

[[ 1. ,    2.,     3.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.3,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.2,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.1,    2.1,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.5,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.7,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.3,    2.2,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.6,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.8,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.4,    2.3,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.7,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.9,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.5,    2.1,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.89,   2.3,    3.5,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.7,    3.2,    2.,     3.2,    3.231,  4.2  ],
 [ 1.9,    2.2,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.22,   3.6,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.8,    3.2,    2.,     3.66,   3.2,    4.2  ],
 [ 1.89,   2.3,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.99,   3.7,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.9,    3.2,    2.,     3.34,   3.2,    4.2  ]]

我想根据第四列将此数组拆分为多个子数组。 即我想要一个第四列等于 1 的子数组,另一个第四列等于 2 的子数组,等等。我事先不知道第四列中所有可能的值是什么。

例如,第四列为 1 对应的子数组是:

[[ 1.     2.     3.     1.     3.     3.     4.   ],
 [ 1.1    2.1    1.     1.     3.     3.     4.   ],
 [ 1.3    2.2    1.     1.     3.     3.     4.   ],
 [ 1.4    2.3    1.     1.     3.     3.     4.   ],
 [ 1.5    2.1    1.     1.     3.     3.     4.   ],
 [ 1.9    2.2    1.     1.     3.     3.     4.   ],
 [ 1.89   2.3    1.     1.     3.     3.     4.   ]]

列出数组:

y = [x[x[:,3]==k] for k in np.unique(x[:,3])]

您可以使用numpy.argsortnumpy.array_splitnumpy.diffnumpy.whereO(NlogN)时间内完成此numpy.where

>>> indices = np.argsort(arr[:, 3])
>>> arr_temp = arr[indices]
>>> np.array_split(arr_temp, np.where(np.diff(arr_temp[:,3])!=0)[0]+1)
[array([[ 1.  ,  2.  ,  3.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.89,  2.3 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.1 ,  2.1 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.9 ,  2.2 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.3 ,  2.2 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.5 ,  2.1 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.4 ,  2.3 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ]]), array([[ 1.2  ,  2.8  ,  3.2  ,  2.   ,  3.66 ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.7  ,  3.2  ,  2.   ,  3.2  ,  3.231,  4.2  ],
       [ 1.2  ,  2.9  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.9  ,  3.2  ,  2.   ,  3.34 ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.8  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.7  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.2  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ]]), array([[ 1.3 ,  2.3 ,  3.6 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.89,  2.3 ,  3.5 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.3 ,  3.5 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.22,  3.6 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.3 ,  3.3 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.99,  3.7 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.3 ,  3.7 ,  3.  ,  3.3 ,  3.3 ,  4.3 ]])]

我将@ashwini-chaudhary 的想法转变为返回感兴趣的索引以供以后迭代的方式。 所以我想我会分享它:

def split_idx_by_dim(dim_array):
    """Returns a sequence of arrays of indices of elements sharing the same value in dim_array"""
    idx = np.argsort(dim_array)
    sorted_cl_ids = dim_array[idx]
    split_idx = np.array_split(idx, np.where(np.diff(sorted_cl_ids) != 0)[0] + 1)
    return split_idx

查看将数组拆分为多个子数组的文档

numpy.hsplit(ary,indexs_or_sections)

水平(按列)将一个数组拆分为多个子数组。

假设您有一个4x4阵列A:

array([[  0.,   1.,   2.,   3.],
   [  4.,   5.,   6.,   7.],
   [  8.,   9.,  10.,  11.],
   [ 12.,  13.,  14.,  15.]])

split = numpy.hsplit(A,4) = 

[array([[  0.],
   [  4.],
   [  8.],
   [ 12.]]), array([[  1.],
   [  5.],
   [  9.],
   [ 13.]]), array([[  2.],
   [  6.],
   [ 10.],
   [ 14.]]), array([[  3.],
   [  7.],
   [ 11.],
   [ 15.]])]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM