繁体   English   中英

如何在 TensorFlow 中获取张量的特定行?

[英]How can I get specific rows of a tensor in TensorFlow?

我有几个张量:

logits:这个张量包含最终的预测分数。

tf.Tensor 'MemN2N_1/MatMul_3:0' shape=(?, 18230) dtype=float32

最终预测计算为 predict_op = tf.argmax(logits, 1, name="predict_op")

现在我想将预测限制在一些特定的列中。 以下两个张量包含我想从中选择的列索引。

self._stories 是类型

tf.Tensor 'stories:0' shape=(?, 12, 110) dtype=int32

self._queries 是类型

tf.Tensor 'queries:0' shape=(?, 110) dtype=int32

这里的 110 列是我想限制登录的索引号。 例如,如果 logits = [[10,20,30,40,50], [10,20,30,40,50]..] 和 self._stories = [[[1,4,...], [1,2,4,...],...], [[0,4,...],[2,4...],...]...] 和 self._queries = [[1,4...],[2,4,...],...] 那么 logits 应该看起来像 [[20,30,50],[10,30,50]...]

如何在 tensorflow 中进行这种索引过滤?

你试过tf.equal吗? 这将比较两个张量并创建一个新的张量,其中包含 True 相等的地方和 False 地方不相等。

使用这个 bool-tensor,您可以tf.select ,它从一个张量或另一个张量中选择元素,具体取决于您在第一步中创建的 bool-value。

没有深入研究您提供的特定形状,但是通过这两个操作,您可以创建您所要求的那种流。

试试tf.gather

row_indices = [1]
row = tf.gather(tf.constant([[1, 2],[3, 4]]), row_indices)
tf.Session().run(row)  # returns [[3, 4]]

您可以使用tf.squeeze删除大小为 1 的前导维度:

row = tf.squeeze(row, squeeze_dims=0)

您可以使用tf.slice(input_, begin, size, name=None) 有关更多详细信息,请参阅 TensorFlow 文档: https : //www.tensorflow.org/api_docs/python/tf/slice

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM