繁体   English   中英

Tensorflow 2:根据二维张量对 3D 张量进行排序

[英]Tensorflow 2: Sort a 3D tensor accoding to a 2D tensor

我有一个 3D 张量,具有批次、序列、特征维度(N、s、e)。 它是一系列概率分布。 然后我想根据最高预测对应的 integer 对它们进行排序。 所以说

x_probabs = 3D tensor (ex: [[[0.5, 0.1, 0.4], [0.3, 0.3, 0.4], [0.1,
0.8, 0.1]]]; # shape N s e

x = tf.argmax(x_probabs, axis=-1) = [[0, 2, 1]]; # shape N s

或者另一个例子是

x_probabs=[[[0.6, 0.1, 0.1, 0.1, 0.1], [0.1,0.1,0.1,0.1,0.6], [0.1,0.1,0.1,0.6,0.1]]];

x = [[0, 4, 3]];

如果我想订购 xi 可以做ordered_x = tf.sort(x, axis=-1) ,那么为了得到排序我可以做indices_sorted_x = tf.argsort(x, axis=-1) 我希望将相同的排序应用于 x_probabs 并且我很困惑如何做到这一点,我尝试过sorted_x_probabs = tf.gather(x_probabs, indices_sorted_x)但它不起作用,因为索引是针对二维张量而不是 3D 的。 我被困在这里。

以下是第一个示例的样子

sorted_x = [[0,1,2]];
sorted_x_probabs = [[[0.5, 0.1, 0.4],[0.1,
    0.8, 0.1],[0.3, 0.3, 0.4]]];

这将是第二个例子

sorted_x = [[0,3,4]];
sorted_x_probabs = [[[0.6, 0.1, 0.1, 0.1, 0.1],[0.1,0.1,0.1,0.6,0.1],[0.1,0.1,0.1,0.1,0.6]]];

非常感谢您提前。

您可以添加batch_dims参数以从较低维度开始收集:

x = tf.gather(x_probabs, x, batch_dims=1)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM