繁体   English   中英

如何在训练/评估期间访问Tensorflow中标签的值?

[英]How do I access the value of a label in Tensorflow during training/eval?

我正在使用Tensorflow中的神经网络,使用自定义丢失函数,该函数基于输入图像的标签,我采样矢量(来自与标签相对应的另一个数据集)并获取输入图像嵌入的点积( softmax预激活)和采样矢量。 我还负面地采样与输入标签冲突的不匹配矢量,以及其标签与当前输入标签不同的另一个随机训练输入图像嵌入。 显然这些都被构造为同维张量,并且自定义损失是:

Loss = max(0, input_embedding * mismatching_sample - input_embedding * matching_sample + 1) \ 
+ max(0, random_embedding * matching_sample - input_embedding * matching_sample + 1)

这主要是为了背景,但我遇到的问题是如何访问与输入图像对应的标签值? 我需要能够访问这些标签的值,以便采样正确的向量并计算我的损失。

我在文档中读到你可以使用.eval()来通过在会话中运行来获取张量的值,但是当我尝试这个时我的终端只是挂了...技术上我在训练时已经在运行一个会话我的神经网络,所以不确定是否存在在另一个会话中运行第二个会话并尝试评估技术上属于另一个正在运行的会话的值的问题。 无论如何,我完全没有关于如何使这项工作的想法。 任何帮助将不胜感激!

这是我原来的尝试证明有问题:

# compute custom loss function using tensors and tensor operations
def compute_loss(e_list, labels, i):
    embedding = e_list[i] #getting the current embedding tensor
    label = labels[i] #getting the matching label tensor, this is value I need.
    y_index = np.nonzero(label)[0][0] #this always returns 0, doesn't work :(
    target = get_mnist_embedding(y_index)
    wrong_mnist = get_mismatch_mnist_embedding(y_index)
    wrong_spec = get_random_spec_embedding(y_index, e_list, labels)
    # compute the loss:
    zero = tf.constant(0,dtype="float32")
    one = tf.constant(1,dtype="float32")
    mul1 = tf.mul(wrong_mnist,embedding)
    dot1 = tf.reduce_sum(mul1)
    mul2 = tf.mul(target,embedding)
    dot2 = tf.reduce_sum(mul2)
    mul3 = tf.mul(target,wrong_spec)
    dot3 = tf.reduce_sum(mul3)
    max1 = tf.maximum(zero, tf.add_n([dot1, tf.negative(dot2), one]))
    max2 = tf.maximum(zero, tf.add_n([dot3, tf.negative(dot2), one]))
    loss = tf.add(max1,max2)
    return loss

问题是我没有使用feed_dict传递采样值。 我试图在损失计算时识别输入标签的张量值,以便然后采样我的损失函数所需的其他值。

我通过预先计算和组织mismatching_samplematching_sample值来解决这个问题,例如,最初使用tf.placeholder来表示这些值,并在执行sess.run([train_op, loss...], feed_dict=feed_dict)时使用feed_dict字典对象传入值sess.run([train_op, loss...], feed_dict=feed_dict)访问我的损失计算值。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM