[英]Numpy repeat for 2d array
给定两个数组,说
arr = array([10, 24, 24, 24, 1, 21, 1, 21, 0, 0], dtype=int32)
rep = array([3, 2, 2, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
np.repeat(arr,rep)返回
array([10, 10, 10, 24, 24, 24, 24], dtype=int32)
有没有办法为一组2D阵列复制此功能?
那是给定的
arr = array([[10, 24, 24, 24, 1, 21, 1, 21, 0, 0],
[10, 24, 24, 1, 21, 1, 21, 32, 0, 0]], dtype=int32)
rep = array([[3, 2, 2, 0, 0, 0, 0, 0, 0, 0],
[2, 2, 2, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
是否可以创建矢量化的功能?
PS:每一行中的重复次数不必相同。 我将每个结果行填充以确保它们的大小相同。
def repeat2d(arr, rep):
# Find the max length of repetitions in all the rows.
max_len = rep.sum(axis=-1).max()
# Create a common array to hold all results. Since each repeated array will have
# different sizes, some of them are padded with zero.
ret_val = np.empty((arr.shape[0], maxlen))
for i in range(arr.shape[0]):
# Repeated array will not have same num of cols as ret_val.
temp = np.repeat(arr[i], rep[i])
ret_val[i,:temp.size] = temp
return ret_val
我确实了解np.vectorize,而且与正常版本相比,它没有任何性能上的好处。
因此,每行有一个不同的重复数组吗? 但是每行的重复总数是否相同?
只需在展平的数组上repeat
上述操作,然后重新整形为正确的行数。
In [529]: np.repeat(arr,rep.flat)
Out[529]: array([10, 10, 10, 24, 24, 24, 24, 10, 10, 24, 24, 24, 24, 1])
In [530]: np.repeat(arr,rep.flat).reshape(2,-1)
Out[530]:
array([[10, 10, 10, 24, 24, 24, 24],
[10, 10, 24, 24, 24, 24, 1]])
如果每行的重复次数有所不同,则我们将填充可变长度的行。 其他SO问题也提到了这一点。 我不记得所有的细节,但是我认为解决方案是这样的:
更改rep
以便数字不同:
In [547]: rep
Out[547]:
array([[3, 2, 2, 0, 0, 0, 0, 0, 0, 0],
[2, 2, 2, 1, 0, 2, 0, 0, 0, 0]])
In [548]: lens=rep.sum(axis=1)
In [549]: lens
Out[549]: array([7, 9])
In [550]: m=np.max(lens)
In [551]: m
Out[551]: 9
创建目标:
In [552]: res = np.zeros((arr.shape[0],m),arr.dtype)
创建索引数组-需要制定详细信息:
In [553]: idx=np.r_[0:7,m:m+9]
In [554]: idx
Out[554]: array([ 0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17])
固定索引分配:
In [555]: res.flat[idx]=np.repeat(arr,rep.flat)
In [556]: res
Out[556]:
array([[10, 10, 10, 24, 24, 24, 24, 0, 0],
[10, 10, 24, 24, 24, 24, 1, 1, 1]])
类似于@hpaulj的解决方案的另一个解决方案:
def repeat2dvect(arr, rep):
lens = rep.sum(axis=-1)
maxlen = lens.max()
ret_val = np.zeros((arr.shape[0], maxlen))
mask = (lens[:,None]>np.arange(maxlen))
ret_val[mask] = np.repeat(arr.ravel(), rep.ravel())
return ret_val
我没有存储索引,而是创建了布尔掩码并使用该掩码设置值。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.