簡體   English   中英

TensorFlow-3D張量,它從2D張量中收集每N個張量並跨越1

[英]TensorFlow - 3D tensor that gathers every Nth tensor from 2D tensor and strides 1

假設我在[M, 1]有一個2D-Tensor T

T = tf.expand_dims([A1,
     B1,
     C1,
     A2,
     B2,
     C2], 1)

我想像這樣重塑它:

T_reshp = [[[A1], [A2]]
           [[B1], [B2]]
           [[C1], [C2]]]

我事先知道MN (每組中張量的數量)。 此外,讓t_reshp.shape[0] = M/N = P在我嘗試使用tf.reshape

T_reshp = tf.reshape(T, [P, N, 1])

但是,我最終得到:

T_reshp = [[[A1], [B1]]
           [[C1], [A2]]
           [[B2], [C2]]]

我可以使用切片或整形操作來執行此操作嗎?

您可以先將其重塑為尺寸[N,P,1]然后將第一和第二條軸transpose

tf.transpose(tf.reshape(T, [N, P, 1]), [1,0,2])
#                           ^^^^ switch the two dimensions here and then transpose

范例

T = tf.expand_dims([1,2,3,4,5,6], 1)
sess = tf.Session()
T1 = tf.transpose(tf.reshape(T, [2,3,1]), [1,0,2])

sess.run(T1)
#array([[[1],
#        [4]],

#       [[2],
#        [5]],

#       [[3],
#        [6]]], dtype=int32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM