簡體   English   中英

使用 Keras 在自定義損失函數中索引張量

[英]Indexing tensors in custom loss function with Keras

我在Keras使用自定義損失函數。 這是函數:

def custom_loss(groups_id_count):
  def listnet_loss(real_labels, predicted_labels):
    losses = tf.placeholder(shape=[None], dtype=tf.float32) # Tensor of rank 1
    for group in groups_id_count:
      start_range = 0
      end_range = (start_range + group[1])
      batch_real_labels = tf.slice(real_labels, [start_range, 1, None], [end_range, 1, None])
      batch_predicted_labels = tf.slice(predicted_labels, [start_range, 0, 0], [end_range, 0, 0])
      loss = -K.sum(get_top_one_probability(batch_real_labels)) * tf.math.log(get_top_one_probability(batch_predicted_labels))
      losses = tf.concat([losses, loss], axis=0)
      start_range = end_range
    return K.mean(losses)
  return listnet_loss

我會得到real_labelspredicted_labels項目從start_rangeend_range ,但目前的代碼返回一個例外:

錯誤:

TypeError: Failed to convert object of type <class 'list'> to Tensor.
Contents: [0, 1, None]. Consider casting elements to a supported type.

我不知道該怎么辦,因為這是我第一次使用TensorFlowKeras 如何使用張量索引獲取項目? 提前致謝。

錯誤是因為在tf.placeholder指定了None shape ,並且它發生在行中,

  batch_real_labels = tf.slice(real_labels, [start_range, 1, None], [end_range, 1, None])

解決方案是將variable定義為該placeholdershape ,並使用該variable而不是None

相同的代碼如下所示:

h = tf.shape(losses)[0]
batch_real_labels = tf.slice(real_labels, [start_range, 1, h], [end_range, 1, h])

此解決方法將修復錯誤TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [0, 1, None]. Consider casting elements to a supported type. TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [0, 1, None]. Consider casting elements to a supported type. 但隨后的代碼行可能會導致其他錯誤。

如果您遇到任何其他錯誤,請分享error ,包括函數、 get_top_one_probability和目標的完整代碼,您究竟想使用該函數實現什么, custom_loss ,我很樂意為您提供幫助。

作為新TensorflowKeras ,我希望你快樂學習!

請使用后端函數K.reshape根據自己的知識重塑輸入標簽和輸入預測。

對於標簽,輸入是未定義的(?,?),因此您需要通過重塑它來修復它。 否則,您無法對其進行索引。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM