繁体   English   中英

Integer 数组索引与广播和 Numpy 中的 alignment

[英]Integer array indexing with broadcasting and alignment in Numpy

假设我们有一个形状为(n, d)的 numpy 数组a 例如,

np.random.seed(1)

n, d = 5, 3
a = np.random.randn(n, d)

现在让indices是一个(m, n)形状的 integer 索引数组,范围在0, 1, ... d上。 也就是说,此数组包含索引 a 的第二维a索引。 例如,

m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))

我想使用indices来索引a的第二维,使其与每个n和批次对齐m

我的解决方案是

result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(result.shape)
# (10, 5)

另一种解决方案是

np.diagonal(a.T[indices], axis1=1, axis2=2)

但我认为我的方法不必要地复杂。 我们是否有任何优雅的“numpitonic”广播来实现这一点,例如像aT[indices]之类的东西?

注意:“优雅的 numpitonic”的定义可能是模棱两可的。 比方说,当mn很大时最快。

也许这个:

np.take_along_axis(a.T, indices, axis=0)

它给出了正确的结果:

np.take_along_axis(a.T, indices, axis=0) == result

output:

array([[ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True],
   [ True,  True,  True,  True,  True]])

关于什么:

result = a[np.indices(indices.shape)[1], indices]

或者:

result = a[np.tile(np.arange(n), m), indices.ravel()].reshape(m,n)

output:

array([[-0.61175641, -1.07296862,  1.74481176,  1.46210794, -0.3224172 ],
       [ 1.62434536,  0.86540763,  0.3190391 ,  1.46210794, -0.3224172 ],
       [-0.52817175, -2.3015387 , -0.7612069 ,  1.46210794, -0.38405435],
       [ 1.62434536, -1.07296862, -0.7612069 , -0.24937038,  1.13376944],
       [ 1.62434536, -1.07296862, -0.7612069 ,  1.46210794,  1.13376944],
       [ 1.62434536, -1.07296862, -0.7612069 , -2.06014071,  1.13376944],
       [-0.61175641, -1.07296862,  0.3190391 ,  1.46210794,  1.13376944],
       [-0.61175641, -1.07296862, -0.7612069 ,  1.46210794,  1.13376944],
       [ 1.62434536, -1.07296862,  0.3190391 , -2.06014071, -0.38405435],
       [ 1.62434536, -2.3015387 ,  0.3190391 , -2.06014071, -0.3224172 ]])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM