简体   繁体   English

Tensorflow:如何使用不规则张量作为正常张量的索引?

[英]Tensorflow: How to use a Ragged Tensor as an index into a normal tensor?

I have a 2D RaggedTensor consisting of indices I want from each row of a full tensor, eg:我有一个 2D RaggedTensor,其中包含我想要的完整张量每一行的索引,例如:

[
    [0,4],
    [1,2,3],
    [5]
]

into进入

[
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

gives

[
    [200,20],
    [315,401,20],
    [105]
]

How can I achieve this in the most efficient way (preferably only with tf functions)?我怎样才能以最有效的方式实现这一点(最好只使用tf函数)? I believe that things like gather_nd are able to take RaggedTensors but I cannot figure out how it works.我相信像gather_nd这样的东西能够使用RaggedTensors,但我不知道它是如何工作的。

You can use tf.gather , with the batch_dims keyword argument:您可以使用tf.gatherbatch_dims关键字参数:

>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM