
ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (10, 30))

I am very new to tensorflow and would greatly appreciate any answers. I am trying to use a transformer model as an embedding layer and feed the data into a custom model.

import tensorflow as tf
from tensorflow.keras import layers
from transformers import TFAutoModel

# MODEL_NAME, config and MAX_LEN are defined elsewhere in my code.
def build_model():
    transformer_model = TFAutoModel.from_pretrained(MODEL_NAME, config=config)

    input_ids_in = layers.Input(shape=(MAX_LEN,), name='input_ids', dtype='int32')
    input_masks_in = layers.Input(shape=(MAX_LEN,), name='attention_mask', dtype='int32')

    # Use the transformer's last hidden state as the embedding sequence.
    embedding_layer = transformer_model(input_ids_in, attention_mask=input_masks_in)[0]

    X = layers.Bidirectional(layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(embedding_layer)
    X = layers.GlobalMaxPool1D()(X)
    X = layers.Dense(64, activation='relu')(X)
    X = layers.Dropout(0.2)(X)
    X = layers.Dense(30, activation='softmax')(X)

    model = tf.keras.Model(inputs=[input_ids_in, input_masks_in], outputs=X)

    # Freeze the two input layers and the transformer.
    for layer in model.layers[:3]:
        layer.trainable = False

    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

    
model = build_model()
model.summary()
# train_ds, train_steps and EPOCHS are defined elsewhere in my code.
r = model.fit(
    train_ds,
    steps_per_epoch=train_steps,
    epochs=EPOCHS,
    verbose=3)

I have 30 classes, and the labels are not one-hot encoded, so I am using sparse_categorical_crossentropy as my loss function, but I keep getting the following error:

ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (10, 30)).

How can I fix this? Why is a (10, 30) shape expected? I understand the 30 comes from the last Dense layer having 30 units, but why 10? Is it because MAX_LEN is 10?
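For reference, here is a minimal way to inspect the shapes train_ds yields (a sketch assuming train_ds is a tf.data.Dataset that yields a dict with the two named inputs plus a label tensor):

# Look at one batch to see how inputs and labels are shaped.
print(train_ds.element_spec)

for features, labels in train_ds.take(1):
    print("input_ids:", features["input_ids"].shape)
    print("attention_mask:", features["attention_mask"].shape)
    print("labels:", labels.shape)  # sparse_categorical_crossentropy expects (batch_size,)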

My model summary:

Model: "model_16"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_ids (InputLayer)          [(None, 10)]         0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 10)]         0                                            
__________________________________________________________________________________________________
tf_bert_model_21 (TFBertModel)  TFBaseModelOutputWit 162841344   input_ids[0][0]                  
                                                                 attention_mask[0][0]             
__________________________________________________________________________________________________
bidirectional_17 (Bidirectional (None, 10, 100)      327600      tf_bert_model_21[0][0]           
__________________________________________________________________________________________________
global_max_pooling1d_15 (Global (None, 100)          0           bidirectional_17[0][0]           
__________________________________________________________________________________________________
dense_32 (Dense)                (None, 64)           6464        global_max_pooling1d_15[0][0]    
__________________________________________________________________________________________________
dropout_867 (Dropout)           (None, 64)           0           dense_32[0][0]                   
__________________________________________________________________________________________________
dense_33 (Dense)                (None, 30)           1950        dropout_867[0][0]                
==================================================================================================
Total params: 163,177,358
Trainable params: 336,014
Non-trainable params: 162,841,344

The 10 is the number of sequences in one batch; I suspect that is how many sequences are in your dataset.

Your model acts as a sequence classifier, so there should be one label per sequence.
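A minimal sketch of the shapes sparse_categorical_crossentropy expects (the batch size of 10 and the random values here are just for illustration):

import numpy as np
import tensorflow as tf

batch_size = 10
# One integer class id per sequence: shape (10,), matching the first
# dimension of the (10, 30) softmax output.
labels = np.random.randint(0, 30, size=(batch_size,))
probs = tf.nn.softmax(tf.random.normal((batch_size, 30)))  # stand-in for the model output

loss = tf.keras.losses.sparse_categorical_crossentropy(labels, probs)
print(loss.shape)  # (10,): one loss value per sequence

# A (1,)-shaped label batch, as in the error message, cannot be paired
# with 10 predictions, which raises the "Shape mismatch" ValueError.

So make sure the dataset batches labels the same way it batches inputs: each batch of 10 sequences should come with 10 integer labels.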
