[英]Combining slicing and broadcasted indexing for multi-dimensional numpy arrays
我有一个ND numpy数组(例如3x3x3),我想提取一个子数组,结合切片和索引数组。 例如:
import numpy as np
A = np.arange(3*3*3).reshape((3,3,3))
i0, i1, i2 = ([0,1], [0,1,2], [0,2])
ind1 = j0, j1, j2 = np.ix_(i0, i1, i2)
ind2 = (j0, slice(None), j2)
B1 = A[ind1]
B2 = A[ind2]
我希望B1 == B2,但是实际上,形状是不同的
>>> B1.shape
(2, 3, 2)
>>> B2.shape
(2, 1, 2, 3)
>>> B1
array([[[ 0, 2],
[ 3, 5],
[ 6, 8]],
[[ 9, 11],
[12, 14],
[15, 17]]])
>>> B2
array([[[[ 0, 3, 6],
[ 2, 5, 8]]],
[[[ 9, 12, 15],
[11, 14, 17]]]])
有人知道为什么吗? 关于如何仅通过操作“ A”和“ ind2”对象来获得“ B1”的任何想法吗? 目标是它适用于任何nD数组,并且我不必寻找要完全保留的尺寸形状(希望我已经很清楚了:)。 谢谢!!
- -编辑 - -
更清楚地说,我想有一个功能“有趣”,使
A[fun(ind2)] == B1
ind1
的索引子空间为(2,),(3,),(2,),结果B
为(2,3,2)
。 这是高级索引的简单情况。
ind2
是(高级)部分索引的一种情况。 有2个索引数组和1个切片。 高级索引文档指出:
如果索引子空间是分离的(按切片对象划分),则首先广播的索引空间,然后是x的切片子空间。
在这种情况下,高级索引从第一个索引和第三个索引构造一个(2,2)
数组,并在末尾附加切片尺寸,从而得到一个(2,2,3)
数组。
我会在https://stackoverflow.com/a/27097133/901925中更详细地说明推理的原因
修复像ind2
这样的元组的一种方法是将每个slice扩展成一个数组。 我最近在np.insert
看到了这一点。
np.arange(*ind2[1].indices(3))
扩展:
到[0,1,2]
。 但是替换必须具有正确的形状。
ind=list(ind2)
ind[1]=np.arange(*ind2[1].indices(3)).reshape(1,-1,1)
A[ind]
我将省略确定哪个术语是一个切片,其尺寸以及相关重塑的详细信息。 目的是复制i1
。
如果索引不是由ix_
生成的,则重塑ix_
难度可能更大。 例如
A[np.array([0,1])[None,:,None],:,np.array([0,2])[None,None,:]] # (1,2,2,3)
A[np.array([0,1])[None,:,None],np.array([0,1,2])[:,None,None],np.array([0,2])[None,None,:]]
# (3,2,2)
扩展的片必须与广播中的其他阵列兼容。
索引后交换轴是另一种选择。 但是,逻辑可能更复杂。 但是在某些情况下,移置实际上可能更简单:
A[np.array([0,1])[:,None],:,np.array([0,2])[None,:]].transpose(2,0,1)
# (3,2,2)
A[np.array([0,1])[:,None],:,np.array([0,2])[None,:]].transpose(0,2,1)
# (2, 3, 2)
这是我能更接近您的规格的信息,我无法设计出一种在不知道A
(或更确切地说,其形状...)的情况下可以计算正确索引的解决方案。
import numpy as np
def index(A, s):
ind = []
groups = s.split(';')
for i, group in enumerate(groups):
if group == ":":
ind.append(range(A.shape[i]))
else:
ind.append([int(n) for n in group.split(',')])
return np.ix_(*ind)
A = np.arange(3*3*3).reshape((3,3,3))
ind2 = index(A,"0,1;:;0,2")
print A[ind2]
较短的版本
def index2(A,s):return np.ix_(*[range(A.shape[i])if g==":"else[int(n)for n in g.split(',')]for i,g in enumerate(s.split(';'))])
ind3 = index2(A,"0,1;:;0,2")
print A[ind3]
在使用ix_
这样的受限索引情况下,可以ix_
索引。
A[ind1]
是相同的
A[i1][:,i2][:,:,i3]
由于i2
是完整范围,
A[i1][...,i3]
如果您只有ind2
A[ind2[0].flatten()][[ind2[2].flatten()]
在更一般的上下文中,您必须知道j0,j1,j2
如何相互广播,但是当它们由ix_
生成时,关系很简单。
我可以想象一下这样的情况:分配A1 = A[i1]
并随后进行涉及A1
的各种动作(包括但不限于A1[...,i3]
。 您必须知道什么时候A1
是视图,什么时候是副本。
另一个索引工具是take
:
A.take(i0,axis=0).take(i2,axis=2)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.