NumPy Array Fill Rows Downward By Indexed Sections
Let's say I have the following (fictitious) NumPy array:
arr = np.array(
[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24],
[25, 26, 27, 28],
[29, 30, 31, 32],
[33, 34, 35, 36],
[37, 38, 39, 40]
]
)
And for row indices idx = [0, 2, 3, 5, 8, 9], I'd like to repeat the values in each row downward until it reaches the next row index:
np.array(
[[1, 2, 3, 4],
[1, 2, 3, 4],
[9, 10, 11, 12],
[13, 14, 15, 16],
[13, 14, 15, 16],
[21, 22, 23, 24],
[21, 22, 23, 24],
[21, 22, 23, 24],
[33, 34, 35, 36],
[37, 38, 39, 40]
]
)
Note that idx will always be sorted and have no repeated values. I can accomplish this by doing something like:
for start, stop in zip(idx[:-1], idx[1:]):
    for i in range(start, stop):
        arr[i] = arr[start]
# Handle last index in `idx`
start, stop = idx[-1], arr.shape[0]
for i in range(start, stop):
    arr[i] = arr[start]
Unfortunately, I have many, many arrays like this, and this can become slow as the arrays get larger (in both the number of rows and the number of columns) and as the length of idx increases. The final goal is to plot these as heatmaps in matplotlib, which I already know how to do. Another approach that I tried was using np.tile:
for start, stop in zip(idx[:-1], idx[1:]):
    reps = max(0, stop - start)
    arr[start:stop] = np.tile(arr[start], (reps, 1))
# Handle last index in `idx`
start, stop = idx[-1], arr.shape[0]
reps = max(0, stop - start)  # recompute; otherwise `reps` is stale from the loop
arr[start:stop] = np.tile(arr[start], (reps, 1))
But I am hoping that there's a way to get rid of the slow for-loop.
Try np.diff to find the number of repetitions for each row, then np.repeat:
# this assumes `idx` is a standard list as in the question
np.repeat(arr[idx], np.diff(idx+[len(arr)]), axis=0)
Output:
array([[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[13, 14, 15, 16],
[21, 22, 23, 24],
[21, 22, 23, 24],
[21, 22, 23, 24],
[33, 34, 35, 36],
[37, 38, 39, 40]])
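The one-liner above relies on list concatenation (`idx+[len(arr)]`), so it would not behave the same way if idx were a NumPy array. A minimal sketch that accepts either form is below; the helper name `fill_rows_down` is made up for illustration, and it assumes, as in the question, that idx is sorted, has no duplicates, and starts at 0.

```python
import numpy as np

def fill_rows_down(arr, idx):
    """Return a new array where each row listed in `idx` is repeated
    downward until the next listed row (or the end of the array).

    Hypothetical helper; assumes `idx` is sorted, has no duplicates,
    and starts at 0, as in the question.
    """
    idx = np.asarray(idx)
    # Length of each section: gaps between consecutive indices,
    # with len(arr) appended so the last section runs to the end.
    counts = np.diff(np.append(idx, len(arr)))
    return np.repeat(arr[idx], counts, axis=0)

# Same data as the question, built with arange for brevity.
arr = np.arange(1, 41).reshape(10, 4)
idx = np.array([0, 2, 3, 5, 8, 9])
filled = fill_rows_down(arr, idx)
```

Note that np.repeat returns a new array rather than modifying arr; if the in-place behavior of the original loops is needed, `arr[:] = fill_rows_down(arr, idx)` should do it. If idx does not start at 0, the result has only `len(arr) - idx[0]` rows, so the rows before the first index would need to be handled separately.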