[英]How can I get specific rows of a tensor in TensorFlow?
我有几个张量:
logits:这个张量包含最终的预测分数。
tf.Tensor 'MemN2N_1/MatMul_3:0' shape=(?, 18230) dtype=float32
最终预测计算为 predict_op = tf.argmax(logits, 1, name="predict_op")
现在我想将预测限制在一些特定的列中。 以下两个张量包含我想从中选择的列索引。
self._stories 是类型
tf.Tensor 'stories:0' shape=(?, 12, 110) dtype=int32
self._queries 是类型
tf.Tensor 'queries:0' shape=(?, 110) dtype=int32
这里的 110 列是我想限制登录的索引号。 例如,如果 logits = [[10,20,30,40,50], [10,20,30,40,50]..] 和 self._stories = [[[1,4,...], [1,2,4,...],...], [[0,4,...],[2,4...],...]...] 和 self._queries = [[1,4...],[2,4,...],...] 那么 logits 应该看起来像 [[20,30,50],[10,30,50]...]
如何在 tensorflow 中进行这种索引过滤?
你试过tf.equal
吗? 这将比较两个张量并创建一个新的张量,其中包含 True 相等的地方和 False 地方不相等。
使用这个 bool-tensor,您可以tf.select
,它从一个张量或另一个张量中选择元素,具体取决于您在第一步中创建的 bool-value。
没有深入研究您提供的特定形状,但是通过这两个操作,您可以创建您所要求的那种流。
试试tf.gather 。
row_indices = [1]
row = tf.gather(tf.constant([[1, 2],[3, 4]]), row_indices)
tf.Session().run(row) # returns [[3, 4]]
您可以使用tf.squeeze删除大小为 1 的前导维度:
row = tf.squeeze(row, squeeze_dims=0)
您可以使用tf.slice(input_, begin, size, name=None)
。 有关更多详细信息,请参阅 TensorFlow 文档: https : //www.tensorflow.org/api_docs/python/tf/slice
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.