簡體   English   中英

keras 中的 CTC 損失實現

[英]CTC loss implementation in keras

我正在嘗試使用 keras 為我的簡化神經網絡實現 CTC 損失:

  
def ctc_lambda_func(args):
    y_pred, y_train, input_length, label_length = args
 
    return K.ctc_batch_cost(y_train, y_pred, input_length, label_length)


x_train = x_train.reshape(x_train.shape[0],20, 10).astype('float32')

input_data = layers.Input(shape=(20,10,))
x=layers.Convolution1D(filters=256, kernel_size=3,  padding="same", strides=1, use_bias=False ,activation= 'relu')(input_data)
x=layers.BatchNormalization()(x)
x=layers.Dropout(0.2)(x)

x=layers.Bidirectional (LSTM(units=200 , return_sequences=True)) (x)
x=layers.BatchNormalization()(x)
x=layers.Dropout(0.2)(x)


y_pred=outputs = layers.Dense(5, activation='softmax')(x)
fun = Model(input_data, y_pred)
# fun.summary()

label_length=np.zeros((3800,1))
input_length=np.zeros((3800,1))

for i in range (3799):
    label_length[i,0]=4
    input_length[i,0]=5 
  
y_train = np.array(y_train)
x_train = np.array(x_train)
input_length = np.array(input_length)
label_length = np.array(label_length) 

  
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, y_train, input_length, label_length])
model =keras.models.Model(inputs=[input_data, y_train, input_length, label_length], outputs=loss_out)
model.compile(loss={'ctc': lambda y_train, y_pred: y_pred}, optimizer = 'adam')
model.fit(x=[x_train, y_train, input_length, label_length],  epochs=10, batch_size=100)

我們有 (3800,4) 維的 y_true(或 y_train),因此我把 label_length=4 和 input_length=5(+1 表示空白)

我面臨這個錯誤:

ValueError: Input tensors to a Model must come from `tf.keras.Input`. Received: [[0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 ...
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]] (missing previous layer metadata).

y_true 是這樣的:

 [[0. 1. 0. 0.]
 [0. 1. 0. 0.]
 ...
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]]

我的問題是什么?

你誤解了長度。 它不是標簽類別的數量,而是序列的實際長度。 CTC 只能用於目標符號數小於輸入狀態數的情況。 從技術上講,輸入和輸出的數量是相同的,但有些輸出是空白的。 (這通常發生在語音識別中,在這種情況下,您有大量輸入信號窗口,而輸出中的 foneme 卻很少。)

假設您必須填充輸入和輸出以將它們批量處理:

  • input_length應該包含批處理中的每個項目,有多少輸入是實際有效的,即沒有填充;

  • label_length應包含模型應為批次中的每個項目生成多少個非空白標簽。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM