[英]TensorFlow concatenate/stack N tensors interleaving last dimensions
Assume we have 4 tensors, a
, b
, c
and d
which all share the same dimensions of (batch_size, T, C)
, we want to create a new tensor X
which has the shape (batch_size, T*4, C)
where the T*4
is interleaved looping between all of the tensors. 假设我们有4个张量a
, b
, c
和d
都具有相同的(batch_size, T, C)
尺寸,我们想创建一个新的张量X
,其形状为(batch_size, T*4, C)
其中T*4
在所有张量之间交错循环。
For example, if a
, b
, c
and d
were tensors of all ones, twos, threes and fours respectively we'd expect X
to be something like 例如,如果a
, b
, c
和d
分别是所有1、2、3和4的张量,则我们希望X
类似于
[[[1,1,1...],
[2,2,2...],
[3,3,3...],
[4,4,4...],
[1,1,1...],
[2,2,2...],
.
.
.
]]
It seems to me that your example array actually has the shape (batch_size, T, C*4)
rather than (batch_size, T*4, C)
. 在我看来,您的示例数组实际上具有(batch_size, T, C*4)
的形状(batch_size, T, C*4)
而不是(batch_size, T*4, C)
的形状。 Anyway, you can get what you need with tf.concat, tf.reshape, and tf.transpose. 无论如何,您可以使用tf.concat,tf.reshape和tf.transpose获得所需的内容。 A simpler example in 2d is as follows: 2d中的一个简单示例如下:
A = tf.ones([2,3])
B = tf.ones([2,3]) * 2
AB = tf.concat([A,B], axis=1)
AB = tf.reshape(AB, [-1, 3])
AB.eval() #array([[1., 1., 1.],
# [2., 2., 2.],
# [1., 1., 1.],
# [2., 2., 2.]], dtype=float32)
You concatenate A and B to get a matrix of shape (2,6). 您将A和B连接起来得到形状为(2,6)的矩阵。 Then you reshape it which interleaves the rows. 然后,您可以对它进行整形,使其与行交错。 To do this in 3d, the dimension which is multiplied by 4 needs to be the last one. 为此,在3d中,要乘以4的尺寸必须是最后一个尺寸。 So you may need to use tf.transpose, interleave using concat and reshape, then transpose again to reorder the dimensions. 因此,您可能需要使用tf.transpose,使用concat进行交织并整形,然后再次进行转置以重新排列尺寸。
I think another option is to use tf.tile . 我认为另一种选择是使用tf.tile 。
import tensorflow as tf
tf.enable_eager_execution()
A = tf.ones((2, 1, 4))
B = tf.ones((2, 1, 4)) * 2
C = tf.ones((2, 1, 4)) * 3
ABC = tf.concat([A, B, C], axis=1)
print(ABC)
#tf.Tensor(
#[[[1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]]
#
# [[1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]]], shape=(2, 3, 4), dtype=float32)
X = tf.tile(ABC, multiples=[1, 3, 1])
print(X)
#tf.Tensor(
#[[[1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]
# [1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]
# [1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]]
#
# [[1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]
# [1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]
# [1. 1. 1. 1.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]]], shape=(2, 9, 4), dtype=float32)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.