繁体   English   中英

pytorch张量索引

[英]Pytorch tensor indexing

我目前正在将一些代码从tensorflow转换为pytorch,我遇到了tf.gather func的问题,没有直接函数将其转换为pytorch。

我想做的基本上是索引,我有两个张量,特征为[minibatch, 60, 2]张量形状和索引张量[minibatch, 8] ,说像第一个张量是张量A ,第二个是B

在Tensorflow中,直接使用tf.gather(A, B, batch_dims=1)

如何在pytorch中实现此目标?

我已经尝试过A[B]索引。 这似乎不起作用

A[0]B[0]有效,但是shape的输出为[8, 2]

我需要[minibatch, 8, 2]的形状

如果我像[stack, 8, 2] 8,2]那样堆叠张量,它可能会起作用[stack, 8, 2]但是我不知道该怎么做

tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great

输出形状为[minibatch, 8, 2]

我认为您正在寻找torch.gather

out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM