簡體   English   中英

運行 CTC 損失 function

[英]Running CTC loss function

我想在莎士比亞數據集上嘗試 CTC 損失 function,在計算損失期間,預測的張量形狀為 (64, 100, 65),這與 (64, 100) 的 label 形狀不匹配。所以我使用了一些數學來轉換維度但有一個錯誤。

代碼

def loss(labels, logits):
  return tf.keras.losses.categorical_crossentropy(labels, logits)

example_batch_loss  = loss(labels=target_example_batch, logits=tf.math.argmax(tf.convert_to_tensor(example_batch_predictions), axis=-1, output_type=tf.int64))

錯誤

無法計算 Mul,因為輸入 #1(從零開始)應該是一個 int64 張量,但它是一個雙張量 [Op:Mul]

請幫我找到一個使用 CTC loss 的解決方案。

您正在提供 model output 的 argmax,即 output 具有最高值的索引。 CTC 損失(就像大多數損失函數一樣)與 logits 一起工作,logits 是由 model 產生的非標准化概率分布。因此,預測形狀 (64, 100, 65) 和目標只有 (64, 100) 沒有錯.

但是請注意,CTC 旨在處理 model output 比目標長得多的情況。 典型的用例是語音識別,其中有大量信號 windows 與相對較少的音素匹配。 如果你的 output 長度和目標長度相同,則 CTC 退化為標准交叉熵。

假設example_batch_predictions是您的 model output 在通過 softmax 對其進行歸一化之前,您應該這樣做:

example_batch_loss  = loss(labels=target_example_batch, logits=example_batch_predictions, axis=-1, output_type=tf.int64))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM