簡體   English   中英

在 tensorflow 中出現 ValueError 說我的形狀不兼容

[英]Getting a ValueError in tensorflow saying that my shapes are incompatible

錯誤:

return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
    C:\Users\selvaa\miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\keras\backend.py:4619 categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)
    C:\Users\selvaa\miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_shape.py:1128 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))

    ValueError: Shapes (None, 1) and (None, 151) are incompatible

我的 model:

x = np.array(x)
y = np.array(y)

x = x/255.0

model = Sequential()
model.add(Conv2D(3, (3,3), input_shape=(128,128,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Flatten())
model.add(Dense(302, activation='relu'))
model.add(Dense(151, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x, y, batch_size=32, epochs=5, verbose=1, validation_split=0.1)

我正在嘗試訓練 model 來識別不同的口袋妖怪,我的數據集有兩張每個 151 個口袋妖怪的圖片(正確標記和全部)。 不知道我做錯了什么。

這是我打印 x.shape 和 y.shape 時發生的情況:

(301, 128, 128, 3) (301,)

使用損失tf.keras.losses.SparseCategoricalCrossEntropy ,如下面的代碼示例所示。

損失 function tf.keras.losses.SparseCategoricalCrossEntropy接受形狀(n_samples,)中的參考標簽和形狀(n_samples, n_classes)中的預測標簽,這將適用於您的數據。 您不能使用categorical_crossentropy ,因為它希望您的標簽是一次性編碼的(請參閱答案底部)。

x = np.array(x)
y = np.array(y)

x = x / 255.0

model = Sequential()
model.add(Conv2D(3, (3,3), input_shape=(128,128,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(302, activation='relu'))
model.add(Dense(151, activation='softmax'))

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(), 
    optimizer='adam', 
    metrics=['accuracy'])

model.fit(x, y, batch_size=32, epochs=5, verbose=1, validation_split=0.1)

另一種解決方案是在訓練之前對標簽進行一次熱編碼,例如使用 function tf.one_hot 如果您使用這種方法,那么您可以使用categorical_crossentropy

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM