繁体   English   中英

如何使用PyTorch Tensor.index_select()?

[英]How to use PyTorch Tensor.index_select()?

当前正在尝试使用PyTorch实现REINFORCE算法。 我希望能够在取消奖励后收集负责任的产出。 因此,给定操作内存,我创建了索引的Tensor,并尝试使用Tensor.index_select,但没有成功。 有人可以帮忙吗?

    rH = np.array(rH) # discounted reward
    aH = np.array(aH) # action_holder
    sH = np.vstack(np.array(sH)) # states holder

    statesTensor = Variable(torch.from_numpy(sH).type(torch.FloatTensor))
    out = model.forward(statesTensor)

    indexes = GuiltyOnes(out, aH)
    flat = out.view(1,-1)

    respos = torch.index_select(flat, 1, torch.from_numpy(indexes).type(torch.LongTensor))

我收到以下错误:

    return IndexSelect.apply(self, dim, index)
    RuntimeError: save_for_backward can only save input or output tensors, but argument 0 doesn't satisfy this condition

您的情况可能与类似,因此,应改用Variable

i = Variable(torch.from_numpy(indexes).long())
respos = torch.index_select(flat, 1, i)

请记住,pytorch错误消息并不总是很准确。 在这种情况下,这是非常令人误解的imo。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM