简体   繁体   English

如何在TensorFlow中编写argmax函数?

[英]How to write an argmax function in TensorFlow?

I'm not talking about tf.argmax but about argmax in the mathematical sense, eg given a discrete set of values and a function find the value that maximizes it. 我不是在讲tf.argmax而是在数学意义上讲argmax,例如,给定一组离散值和一个函数,找到将其最大化的值。 I currently have something like this 我目前有这样的东西

input = tf.placeholder(name='input')
Qhat = do_stuff_to(input) # e.g. tf.add(input, 3)

Now I want to define another TensorFlow node, max_Qhat , which will take an array of Tensors as its argument. 现在我想定义另一个TensorFlow节点max_Qhat ,它将一个max_Qhat数组作为其参数。 It will feed each of these tensors to Qhat and return the one that resulted in the greatest value. 它将这些张量中的每一个馈送到Qhat并返回产生Qhat张量。 How do I do this? 我该怎么做呢? ( NOTE I am not looking to run Qhat, so no session.run . I just want to define a function that evaluates it.) 注意,我不希望运行 Qhat,所以没有session.run 。我只想定义一个评估它的函数。)

Code I have so far: 我到目前为止的代码:

inputs = tf.placeholder(name='inputs')
max_Qhat = ???

Create an input tensor, say of dimensions [k, ...] (this is your "array" of k input tensors that we "maximize" over) then compute the Qhat op over the first axis/dimension of the input tensor (eg. using tf.map_fn ) such that it returns a tensor of function values of dimensions [k] and then determine the maximum of this returned tensor. 创建一个输入张量,例如尺寸[k, ...] (这是我们“最大化”的k输入张量的“数组”),然后在输入张量的第一个轴/维度上计算Qhat op(例如使用tf.map_fn ),以便它返回尺寸为[k]的函数值的张量,然后确定此返回的张量的最大值。

import tensorflow as tf;

sess = tf.InteractiveSession();

# define inputs
inputs = tf.constant([
    [1, 2, 3, 4, 5],
    [4, 3, 6, 2, 1],
    [9, 9, 9, 9, 9],
    [0, 1, 3, 5, 2]
], shape = (4, 5));

# the function/op that we want to compute for each input tensor
def some_op(t):
    return tf.reduce_sum(t);

# compute the function values
q = tf.map_fn(some_op, inputs);

# determine the index of the input tensor that maximizes the function
index = sess.run(tf.argmax(q, axis = 0));
maximizer = sess.run(inputs[index]);

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM