[英]Pytorch tensor indexing
我目前正在将一些代码从tensorflow转换为pytorch,我遇到了tf.gather
func的问题,没有直接函数将其转换为pytorch。
我想做的基本上是索引,我有两个张量,特征为[minibatch, 60, 2]
张量形状和索引张量[minibatch, 8]
,说像第一个张量是张量A
,第二个是B
在Tensorflow中,直接使用tf.gather(A, B, batch_dims=1)
如何在pytorch中实现此目标?
我已经尝试过A[B]
索引。 这似乎不起作用
和A[0]B[0]
有效,但是shape的输出为[8, 2]
我需要[minibatch, 8, 2]
的形状
如果我像[stack, 8, 2]
8,2]那样堆叠张量,它可能会起作用[stack, 8, 2]
但是我不知道该怎么做
tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great
输出形状为[minibatch, 8, 2]
我认为您正在寻找torch.gather
out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.