[英]Efficient multidimensional indexing in numpy
我想根据我拥有的多维索引对多维 numpy 数组进行“排序”。
让我们从一个例子开始:
# The array I would like to sort
A = np.ones([3, 10, 2])
# The index array
i = np.ones([10, 2], dtype=int)
# Bring some life into sample data
samples = np.arange(10, dtype=int)
A[2, :, 0] = A[2, :, 0] * samples
np.random.shuffle(samples)
A[2, :, 1] = A[2, :, 1] * samples
i[:, 0] = i[:, 0] * samples
np.random.shuffle(samples)
i[:, 1] = i[:, 1] * samples
所以我的数组A
包含 2 个切片,每组 10 组,每组 3 个值。 我想要做的是单独对每个切片进行排序,同时将每个集合保持在一起。 拥有索引数组i
,我的解决方案是:
A = A[:, i]
shape = A.shape
A = A.reshape([shape[0], shape[1], shape[2] + shape[3]])
A = A[:, :, [0, 3]]
我第一次使用i
来索引A
。 这将创建一个新维度,其中i
每一列都应用于A
最终得到一个形状为(4, 10, 2, 2)
的数组。 因为我只需要 4 个结果中的两个,所以我重塑了数组并删除了我不需要的信息。
这种方法效果很好,但我想知道是否有更有效或更优雅的解决方案。
朱尔兹
您可以在此处使用advanced indexing
:
A[:,i, np.arange(A.shape[2])]
与目前的方法相比 -
out = A[:, i]
shape = out .shape
out = out .reshape([shape[0], shape[1], shape[2] + shape[3]])
out = out [:, :, [0, 3]]
np.allclose(A[:,i, np.arange(A.shape[2])], out)
#True
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.