
logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]

I am training an RNN-based English-to-Hindi neural machine translation model with an LSTM layer and an attention layer. I am getting the error (0) Invalid argument: logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]

My model summary is:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 40)     93760       input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 40)     118840      input_2[0][0]                    
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 16)     7056        embedding[0][0]                  
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 16)     10256       embedding_1[0][0]                
__________________________________________________________________________________________________
lstm (LSTM)                     [(None, 40), (None,  9120        conv1d[0][0]                     
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, None, 40), ( 9120        conv1d_1[0][0]                   
                                                                 lstm[0][1]                       
                                                                 lstm[0][2]                       
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 2971)   121811      lstm_1[0][0]                     
==================================================================================================
Total params: 369,963
Trainable params: 369,963
Non-trainable params: 0
__________________________________________________________________________________________________

The model is compiled with:

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
train_samples = len(X_train)
val_samples = len(X_test)
batch_size = 32
epochs = 100

The fit call is:

history = model.fit_generator(generator = generate_batch(X_train, y_train, batch_size = batch_size),
                    steps_per_epoch = train_samples//batch_size,
                    epochs=epochs,
                    validation_data = generate_batch(X_test, y_test, batch_size = batch_size),
                    validation_steps = val_samples//batch_size)
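As the deprecation warning in the log notes, `Model.fit` now accepts generators directly, so `fit_generator` can be replaced one-for-one. A minimal self-contained sketch (toy model, toy data, and a stand-in batch generator, purely illustrative):

```python
import numpy as np
import tensorflow as tf

# Stand-in for the question's generate_batch: yields (inputs, targets) forever.
def generate_batch(X, y, batch_size=4):
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size], y[i:i + batch_size]

X = np.random.rand(16, 8).astype('float32')
y = np.random.rand(16, 1).astype('float32')

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer='adam', loss='mse')

# model.fit takes the generator exactly where fit_generator did.
history = model.fit(generate_batch(X, y), steps_per_epoch=4, epochs=1, verbose=0)
```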

The error is:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:1844: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.
  warnings.warn('`Model.fit_generator` is deprecated and '
Epoch 1/100
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-39-dc64566948be> in <module>()
      3                     epochs=epochs,
      4                     validation_data = generate_batch(X_test, y_test, batch_size = batch_size),
----> 5                     validation_steps = val_samples//batch_size)

7 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]
     [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at <ipython-input-39-dc64566948be>:5) ]]
     [[gradient_tape/model_1/embedding_3/embedding_lookup/Reshape/_56]]
  (1) Invalid argument:  logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]
     [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at <ipython-input-39-dc64566948be>:5) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_11885]

Function call stack:
train_function -> train_function
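A likely reading of the shapes (an inference from the posted summary, not confirmed by the source): with batch_size=32, labels_size 864 = 32 × 27 target timesteps while logits_size 384 = 32 × 12. A Conv1D with the default padding='valid' trims kernel_size − 1 steps from the time axis, and judging by the parameter count (10256) the decoder Conv1D has kernel_size=16, which trims 27 steps down to 12. A quick shape check, with padding='same' shown as the fix that keeps logits time-aligned with the labels:

```python
import tensorflow as tf

# (batch, target timesteps, embedding channels) as in the question's decoder
x = tf.zeros((32, 27, 40))

# Default padding='valid' trims kernel_size - 1 = 15 timesteps: 27 -> 12
y_valid = tf.keras.layers.Conv1D(16, kernel_size=16)(x)
print(y_valid.shape)  # (32, 12, 16) -> 32 * 12 = 384 logit rows

# padding='same' preserves the time axis: 27 timesteps, matching the labels
y_same = tf.keras.layers.Conv1D(16, kernel_size=16, padding='same')(x)
print(y_same.shape)   # (32, 27, 16) -> 32 * 27 = 864 rows
```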

You have to connect the embedding to the CNN, and the CNN to the LSTM, correctly.

Encoder:

num_encoder_tokens = 333
latent_dim = 128

encoder_inputs = Input(shape=(None,))
enc_emb = Embedding(num_encoder_tokens, latent_dim, mask_zero = True)(encoder_inputs)
encoder_CNN = Conv1D(16, kernel_size=11, activation='relu')(enc_emb)
encoder_lstm = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(encoder_CNN)
encoder_states = [state_h, state_c]

Decoder:

num_decoder_tokens = 333

decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero = True)
dec_emb = dec_emb_layer(decoder_inputs)
decoder_CNN = Conv1D(16, kernel_size=11, activation='relu')(dec_emb)
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_CNN, initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

Summary:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 128)    42624       input_2[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 128)    42624       input_4[0][0]                    
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 16)     22544       embedding[0][0]                  
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 16)     22544       embedding_1[0][0]                
__________________________________________________________________________________________________
lstm (LSTM)                     [(None, 128), (None, 74240       conv1d[0][0]                     
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, None, 128),  74240       conv1d_1[0][0]                   
                                                                 lstm[0][1]                       
                                                                 lstm[0][2]                       
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 333)    42957       lstm_1[0][0]                     
==================================================================================================
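The answer's decoder Conv1D still uses the default valid padding, which shortens the target-side sequence relative to the labels. A minimal end-to-end sketch of the same encoder–decoder wiring with padding='same' on both Conv1D layers, so the model emits one logit row per target timestep (mask_zero is omitted here to keep the sketch simple, since Conv1D does not propagate masks):

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Conv1D, LSTM, Dense
from tensorflow.keras.models import Model

num_encoder_tokens = 333
num_decoder_tokens = 333
latent_dim = 128

# Encoder: embedding -> CNN -> LSTM, keeping only the final states
encoder_inputs = Input(shape=(None,))
enc_emb = Embedding(num_encoder_tokens, latent_dim)(encoder_inputs)
enc_cnn = Conv1D(16, kernel_size=11, padding='same', activation='relu')(enc_emb)
_, state_h, state_c = LSTM(latent_dim, return_state=True)(enc_cnn)

# Decoder: padding='same' preserves the time axis, so logits stay
# aligned one-to-one with the target timesteps
decoder_inputs = Input(shape=(None,))
dec_emb = Embedding(num_decoder_tokens, latent_dim)(decoder_inputs)
dec_cnn = Conv1D(16, kernel_size=11, padding='same', activation='relu')(dec_emb)
dec_out, _, _ = LSTM(latent_dim, return_sequences=True, return_state=True)(
    dec_cnn, initial_state=[state_h, state_c])
outputs = Dense(num_decoder_tokens, activation='softmax')(dec_out)

model = Model([encoder_inputs, decoder_inputs], outputs)

# With 27 target timesteps, the logits now also span 27 timesteps:
pred = model([tf.ones((2, 30), dtype=tf.int32), tf.ones((2, 27), dtype=tf.int32)])
print(pred.shape)  # (2, 27, 333)
```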
