簡體   English   中英

ValueError:形狀不匹配:標簽的形狀(收到的 (1,))應該等於 logits 的形狀,除了最后一個維度(收到的 (10, 30))

[英]ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (10, 30))

我對 tensorflow 很陌生,非常感謝您的回答。 我正在嘗試使用變壓器 model 作為嵌入層並將數據提供給自定義 model。

from transformers import TFAutoModel
from tensorflow.keras import layers
def build_model():
    transformer_model = TFAutoModel.from_pretrained(MODEL_NAME, config=config)
    
    input_ids_in = layers.Input(shape=(MAX_LEN,), name='input_ids', dtype='int32')
    input_masks_in = layers.Input(shape=(MAX_LEN,), name='attention_mask', dtype='int32')

    embedding_layer = transformer_model(input_ids_in, attention_mask=input_masks_in)[0]

    X = layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(embedding_layer)
    X = layers.GlobalMaxPool1D()(X)
    X = layers.Dense(64, activation='relu')(X)
    X = layers.Dropout(0.2)(X)
    X = layers.Dense(30, activation='softmax')(X)

    model = tf.keras.Model(inputs=[input_ids_in, input_masks_in], outputs = X)

    for layer in model.layers[:3]:
        layer.trainable = False

    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

    
model = build_model()
model.summary()
r = model.fit(
            train_ds,
            steps_per_epoch=train_steps,
            epochs=EPOCHS,
            verbose=3)

我有 30 個類,標簽不是單熱編碼的,所以我使用 sparse_categorical_crossentropy 作為我的損失 function 但我不斷收到以下錯誤

ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (10, 30)).

我該如何解決這個問題? 為什么需要 (10, 30) 形狀? 我知道 30 是因為最后一個 Dense 層有 30 個單位,但為什么是 10? 是因為 MAX_LENGTH 是 10 嗎?

我的 model 總結:

Model: "model_16"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_ids (InputLayer)          [(None, 10)]         0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 10)]         0                                            
__________________________________________________________________________________________________
tf_bert_model_21 (TFBertModel)  TFBaseModelOutputWit 162841344   input_ids[0][0]                  
                                                                 attention_mask[0][0]             
__________________________________________________________________________________________________
bidirectional_17 (Bidirectional (None, 10, 100)      327600      tf_bert_model_21[0][0]           
__________________________________________________________________________________________________
global_max_pooling1d_15 (Global (None, 100)          0           bidirectional_17[0][0]           
__________________________________________________________________________________________________
dense_32 (Dense)                (None, 64)           6464        global_max_pooling1d_15[0][0]    
__________________________________________________________________________________________________
dropout_867 (Dropout)           (None, 64)           0           dense_32[0][0]                   
__________________________________________________________________________________________________
dense_33 (Dense)                (None, 30)           1950        dropout_867[0][0]                
==================================================================================================
Total params: 163,177,358
Trainable params: 336,014
Non-trainable params: 162,841,344

10 是一批中的序列數。 我懷疑這是您的數據集中的許多序列。

您的 model 充當序列分類器。 因此,每個序列都應該有一個 label。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM