繁体   English   中英

Tensorflow实现的简单逻辑

[英]Tensorflow implementation for a simple logic

我是Tensorflow的新手。 我想在Tensorflow中实现一种算法,该算法具有下面描述的简单逻辑的一部分。

我有一个矩阵(矩阵的大小可以与批处理大小不同)。 我需要将矩阵中的所有值替换为零,但每一行中的最大值除外。

我想在Tensorflow API中实现这个简单的逻辑。

例如:

Input:
[[0.50041455 0.41183667 0.37627002]
 [0.57736448 0.90280652 0.70880312]
 [0.50961863 0.94126878 0.86982843]
 [0.30285231 0.6302716  0.76009756]]
Output:
[[0.50041455 0.         0.        ]
 [0.         0.9028065  0.        ]
 [0.         0.9412688  0.        ]
 [0.         0.         0.76009756]]

这是我的代码。 但我不想将占位符用于“ Y”并运行tenserflow 2次。 (我担心GPU的性能)

from __future__ import print_function
import tensorflow as tf
import numpy as np

num_classes = 3
batch_size = 4

X = tf.placeholder("float", [None, num_classes]) #My matrix

values, indices = tf.nn.top_k(X, 1)
Y = tf.placeholder("float", [None, num_classes])

M = X * Y

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    x = np.random.rand(batch_size, num_classes)
    print(x)
    ind = sess.run([indices], feed_dict={X: x})
    y = np.zeros(shape=[batch_size, num_classes], dtype=float)
    for i in range(batch_size):
        y[i, ind[0][i][0]]= 1.0
    m = sess .run(M, feed_dict={X: x, Y: y})
    print(m)

有任何想法吗?

我想这就是你想要的。

理想情况下,我们希望将行中具有最大元素的位置标记为1,将其他元素标记为0。然后,我们只需要将这两个矩阵逐元素相乘即可。

为了获得大元素的位置,我们可以使用tf.argmax来帮助我们找到相应轴上最大元素的索引,并使用tf.one_hot将上方的结果转换为与初始输入矩阵相同的形状。 有关详细behavor tf.one_hot ,您可能希望阅读https://www.tensorflow.org/api_docs/python/tf/one_hot

示例代码如下:

import tensorflow as tf

num_classes = 3

inputs = tf.placeholder(tf.float32, (None, num_classes))

# Calculate the max element in each row, -1 means the last axis
max_in_row = tf.argmax(inputs, axis=-1)

output = tf.one_hot(max_in_row, num_classes) * inputs

with tf.Session() as sess:
    print (
        sess.run(output, feed_dict={
            inputs:
            [[0.50041455, 0.41183667, 0.37627002],
             [0.57736448, 0.90280652, 0.70880312],
             [0.50961863, 0.94126878, 0.86982843],
             [0.30285231, 0.6302716,  0.76009756]]
        })
    )

输出:

[[0.50041455 0.         0.        ]
 [0.         0.9028065  0.        ]
 [0.         0.9412688  0.        ]
 [0.         0.         0.76009756]]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM