簡體   English   中英

如何在訓練/評估期間訪問Tensorflow中標簽的值?

[英]How do I access the value of a label in Tensorflow during training/eval?

我正在使用Tensorflow中的神經網絡,使用自定義丟失函數,該函數基於輸入圖像的標簽,我采樣矢量(來自與標簽相對應的另一個數據集)並獲取輸入圖像嵌入的點積( softmax預激活)和采樣矢量。 我還負面地采樣與輸入標簽沖突的不匹配矢量,以及其標簽與當前輸入標簽不同的另一個隨機訓練輸入圖像嵌入。 顯然這些都被構造為同維張量,並且自定義損失是:

Loss = max(0, input_embedding * mismatching_sample - input_embedding * matching_sample + 1) \ 
+ max(0, random_embedding * matching_sample - input_embedding * matching_sample + 1)

這主要是為了背景,但我遇到的問題是如何訪問與輸入圖像對應的標簽值? 我需要能夠訪問這些標簽的值,以便采樣正確的向量並計算我的損失。

我在文檔中讀到你可以使用.eval()來通過在會話中運行來獲取張量的值,但是當我嘗試這個時我的終端只是掛了...技術上我在訓練時已經在運行一個會話我的神經網絡,所以不確定是否存在在另一個會話中運行第二個會話並嘗試評估技術上屬於另一個正在運行的會話的值的問題。 無論如何,我完全沒有關於如何使這項工作的想法。 任何幫助將不勝感激!

這是我原來的嘗試證明有問題:

# compute custom loss function using tensors and tensor operations
def compute_loss(e_list, labels, i):
    embedding = e_list[i] #getting the current embedding tensor
    label = labels[i] #getting the matching label tensor, this is value I need.
    y_index = np.nonzero(label)[0][0] #this always returns 0, doesn't work :(
    target = get_mnist_embedding(y_index)
    wrong_mnist = get_mismatch_mnist_embedding(y_index)
    wrong_spec = get_random_spec_embedding(y_index, e_list, labels)
    # compute the loss:
    zero = tf.constant(0,dtype="float32")
    one = tf.constant(1,dtype="float32")
    mul1 = tf.mul(wrong_mnist,embedding)
    dot1 = tf.reduce_sum(mul1)
    mul2 = tf.mul(target,embedding)
    dot2 = tf.reduce_sum(mul2)
    mul3 = tf.mul(target,wrong_spec)
    dot3 = tf.reduce_sum(mul3)
    max1 = tf.maximum(zero, tf.add_n([dot1, tf.negative(dot2), one]))
    max2 = tf.maximum(zero, tf.add_n([dot3, tf.negative(dot2), one]))
    loss = tf.add(max1,max2)
    return loss

問題是我沒有使用feed_dict傳遞采樣值。 我試圖在損失計算時識別輸入標簽的張量值,以便然后采樣我的損失函數所需的其他值。

我通過預先計算和組織mismatching_samplematching_sample值來解決這個問題,例如,最初使用tf.placeholder來表示這些值,並在執行sess.run([train_op, loss...], feed_dict=feed_dict)時使用feed_dict字典對象傳入值sess.run([train_op, loss...], feed_dict=feed_dict)訪問我的損失計算值。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM