[英]How to properly stack numpy arrays?
I am having trouble understanding how data is being stacked in a numpy array and why I cannot match the last data that I added to an array with the last generated data.我无法理解数据是如何堆叠在 numpy 数组中的,以及为什么我无法将添加到数组中的最后一个数据与最后生成的数据相匹配。 Here is a MWE:
这是一个MWE:
import numpy as np
np.random.seed(1)
# build storage
container = []
# gen data
x = np.random.random((13, 1, 64, 768))
# add to container
container.append(x)
# gen data
x2 = np.random.random((13, 1, 64, 768))
# add to container
container.append(x2)
# convert to np array
container = np.asarray(container)
# reshape to [13, 2, 64, 768]
container = container.reshape(13, 2, 64, 768)
# check that the last generated data matches the last appended data
assert np.all(x2.flatten() == container[:, -1, :, :].flatten()), 'not a match'
Instead of stacking manually with appending to lists and then reshaping you could use the vstack or the concatenate function of numpy.您可以使用 vstack 或 numpy 的串联 function 来代替手动堆叠并附加到列表然后重塑。
# gen data
x1 = np.random.random((13, 1, 64, 768))
x2 = np.random.random((13, 1, 64, 768))
container = np.vstack((x1,x2))
assert np.all(x2.flatten()) == np.all(container[:, -1, :, :].flatten()), 'not a match'
To answer your question: your code does work, just make sure to put np.all()
at both sides of the comparison.要回答您的问题:您的代码确实有效,只需确保将
np.all()
放在比较的两侧。 It's always a good idea to make your input much smaller (say (2,1,2,2)) so you can see what actually happens.让你的输入更小(比如(2,1,2,2))总是一个好主意,这样你就可以看到实际发生了什么。
In [152]: alist = []
In [154]: alist.append(np.random.random((2,1,3)))
In [155]: alist.append(np.random.random((2,1,3)))
In [156]: alist
Out[156]:
[array([[[0.85221826, 0.56088315, 0.06232853]],
[[0.0966469 , 0.89513922, 0.44814579]]]),
array([[[0.86207845, 0.88895573, 0.62069196]],
[[0.11475614, 0.29473531, 0.11179268]]])]
Using np.array
to join the list elements produces a 4d array - it has joined them on a new leading dimension:使用
np.array
连接列表元素会生成一个 4d 数组 - 它已将它们连接到一个新的前导维度上:
In [157]: arr = np.array(alist)
In [158]: arr.shape
Out[158]: (2, 2, 1, 3)
In [159]: arr[-1,] # same as alist[-1]
Out[159]:
array([[[0.86207845, 0.88895573, 0.62069196]],
[[0.11475614, 0.29473531, 0.11179268]]])
If we concatenate
on one of the dimensions:如果我们在其中一个维度上
concatenate
:
In [160]: arr = np.concatenate(alist, axis=1)
In [161]: arr
Out[161]:
array([[[0.85221826, 0.56088315, 0.06232853],
[0.86207845, 0.88895573, 0.62069196]],
[[0.0966469 , 0.89513922, 0.44814579],
[0.11475614, 0.29473531, 0.11179268]]])
In [162]: arr.shape
Out[162]: (2, 2, 3) # note the shape - that 2nd 2 is the join axis
In [163]: arr[:,-1]
Out[163]:
array([[0.86207845, 0.88895573, 0.62069196],
[0.11475614, 0.29473531, 0.11179268]])
[163] has the same numbers as [159], but a (2,3) shape. [163] 与 [159] 具有相同的数字,但形状为 (2,3)。
reshape
keeps the values, but may 'shuffle' them: reshape
保留值,但可能会“洗牌”它们:
In [164]: np.array(alist).reshape(2,2,3)
Out[164]:
array([[[0.85221826, 0.56088315, 0.06232853],
[0.0966469 , 0.89513922, 0.44814579]],
[[0.86207845, 0.88895573, 0.62069196],
[0.11475614, 0.29473531, 0.11179268]]])
We have transpose the leading 2 axes before reshape to match [161]我们在 reshape 之前转置了前 2 个轴以匹配 [161]
In [165]: np.array(alist).transpose(1,0,2,3)
Out[165]:
array([[[[0.85221826, 0.56088315, 0.06232853]],
[[0.86207845, 0.88895573, 0.62069196]]],
[[[0.0966469 , 0.89513922, 0.44814579]],
[[0.11475614, 0.29473531, 0.11179268]]]])
In [166]: np.array(alist).transpose(1,0,2,3).reshape(2,2,3)
Out[166]:
array([[[0.85221826, 0.56088315, 0.06232853],
[0.86207845, 0.88895573, 0.62069196]],
[[0.0966469 , 0.89513922, 0.44814579],
[0.11475614, 0.29473531, 0.11179268]]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.