簡體   English   中英

Tensorflow 等效於 torch.Tensor.index_copy

[英]Tensorflow equivalent to torch.Tensor.index_copy

我正在實現等效於 model 這里最初使用實現。 一切都很順利,直到我遇到了這行特定的代碼。

batch_current = Variable(torch.zeros(size, self.embedding_dim))

# self.embedding and self.W_c are pytorch network layers I have created
batch_current = self.W_c(batch_current.index_copy(0, Variable(torch.LongTensor(index)),
                                                         self.embedding(Variable(self.th.LongTensor(current_node)))))

如果搜索index_copy的文檔,似乎它所做的只是在某個索引和公共軸上復制一組元素並將其分配給另一個張量。 但我真的不想寫一些有問題的代碼,所以在嘗試任何自我實現之前,我想知道你們是否知道我如何 go 來實現它。

model 來自本文,是的,我搜索了其他實現,但它們對我來說似乎沒有多大意義。

您需要的是tensorflow中的來獲得類似pytorch的等效操作。 下面是一個演示。

中,您有

import torch 

tensor = torch.zeros(5, 3)
indices = torch.tensor([0, 4, 2])
updates= torch.tensor([[1, 2, 3], 
                       [4, 5, 6], 
                       [7, 8, 9]], dtype=torch.float)
tensor.index_copy_(0, indices, updates)

tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]])

而在 ,你可以做

import tensorflow as tf

tensor = tf.zeros([5,3])
indices = tf.constant([[0], [4], [2]])
updates  = tf.constant([[1, 2, 3], 
                        [4, 5, 6], 
                        [7, 8, 9]], dtype=tf.float32)
tensor = tf.tensor_scatter_nd_update(tensor, indices, updates)
tensor.numpy()
array([[1., 2., 3.],
       [0., 0., 0.],
       [7., 8., 9.],
       [0., 0., 0.],
       [4., 5., 6.]], dtype=float32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM