[英]Flatten 3D tensor
我有一個形狀為 T x B x N 的張量(RNN 的訓練數據,T 是最大 seq 長度,B 是批次數,N 個特征),我想跨時間步展平所有特征,這樣我得到一個形狀為 B x TN 的張量。 一直想不通這個怎么辦。。
您需要在展平之前排列軸,如下所示:
t = t.swapdims(0,1) # (T,B,N) -> (B,T,N)
t = t.view(B,-1) # (B,T,N) -> (B,T*N) (equivalent to `t.view(B,T*N)`)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.