简体   繁体   English

Tensorflow 等效于 torch.Tensor.index_copy

[英]Tensorflow equivalent to torch.Tensor.index_copy

I am implementing the equivalent of the model here originally implemented using .我正在实现等效于 model 这里最初使用实现。 Everything was going smoothly until I encountered this particular line of code.一切都很顺利,直到我遇到了这行特定的代码。

batch_current = Variable(torch.zeros(size, self.embedding_dim))

# self.embedding and self.W_c are pytorch network layers I have created
batch_current = self.W_c(batch_current.index_copy(0, Variable(torch.LongTensor(index)),
                                                         self.embedding(Variable(self.th.LongTensor(current_node)))))

If search for the documentation of index_copy and it seems all it does is to copy a group of elements at a certain index and on a common axis and assign it to another tensor.如果搜索index_copy的文档,似乎它所做的只是在某个索引和公共轴上复制一组元素并将其分配给另一个张量。 But I don't really want to write some buggy code, so before attempting any self-implementation, I wish to know if you folks have an idea of how I can go about implementing it.但我真的不想写一些有问题的代码,所以在尝试任何自我实现之前,我想知道你们是否知道我如何 go 来实现它。

The model is from this paper and yes, I have searched other implementations, but they don't seem to make so much sense to me. model 来自本文,是的,我搜索了其他实现,但它们对我来说似乎没有多大意义。

What you need is the tf.tensor_scatter_nd_update in to get equivalent operation like Tensor.index_copy_ of .您需要的是tensorflow中的来获得类似pytorch的等效操作。 Here is one demonstration shown below.下面是一个演示。

In , you have中,您有

import torch 

tensor = torch.zeros(5, 3)
indices = torch.tensor([0, 4, 2])
updates= torch.tensor([[1, 2, 3], 
                       [4, 5, 6], 
                       [7, 8, 9]], dtype=torch.float)
tensor.index_copy_(0, indices, updates)

tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]])

And in , you can do而在 ,你可以做

import tensorflow as tf

tensor = tf.zeros([5,3])
indices = tf.constant([[0], [4], [2]])
updates  = tf.constant([[1, 2, 3], 
                        [4, 5, 6], 
                        [7, 8, 9]], dtype=tf.float32)
tensor = tf.tensor_scatter_nd_update(tensor, indices, updates)
tensor.numpy()
array([[1., 2., 3.],
       [0., 0., 0.],
       [7., 8., 9.],
       [0., 0., 0.],
       [4., 5., 6.]], dtype=float32)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM