I am implementing the TensorFlow equivalent of the model here, originally implemented in PyTorch. Everything was going smoothly until I encountered this particular line of code.
batch_current = Variable(torch.zeros(size, self.embedding_dim))
# self.embedding and self.W_c are PyTorch network layers I have created
batch_current = self.W_c(batch_current.index_copy(
    0,
    Variable(torch.LongTensor(index)),
    self.embedding(Variable(self.th.LongTensor(current_node)))))
I searched for the documentation of index_copy, and it seems all it does is copy a group of elements at a given index along a common axis into another tensor. I don't want to write buggy code, so before attempting my own implementation I'd like to know whether anyone has an idea of how I can go about implementing it.
The model is from this paper, and yes, I have searched for other TensorFlow implementations, but they don't make much sense to me.
What you need is tf.tensor_scatter_nd_update in TensorFlow to get the equivalent of PyTorch's Tensor.index_copy_. Here is a demonstration.
In PyTorch, you have:
import torch

tensor = torch.zeros(5, 3)
indices = torch.tensor([0, 4, 2])
updates = torch.tensor([[1, 2, 3],
                        [4, 5, 6],
                        [7, 8, 9]], dtype=torch.float)
tensor.index_copy_(0, indices, updates)
tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]])
And in TensorFlow, you can do:
import tensorflow as tf

tensor = tf.zeros([5, 3])
indices = tf.constant([[0], [4], [2]])
updates = tf.constant([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]], dtype=tf.float32)
tensor = tf.tensor_scatter_nd_update(tensor, indices, updates)
tensor.numpy()
array([[1., 2., 3.],
       [0., 0., 0.],
       [7., 8., 9.],
       [0., 0., 0.],
       [4., 5., 6.]], dtype=float32)