
SparseCategoricalCrossentropy Shape Mismatch

I wanted to run a simple test of the SparseCategoricalCrossentropy function to see what exactly it does to the output. For this I used the output of the last layer of MobileNetV2.

    import keras.backend as K
    import tensorflow as tf
    import numpy as np

    full_model = tf.keras.applications.MobileNetV2(
        input_shape=(224,224,3),
        alpha=1.0,
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        pooling=None,
        classes=1000,
        classifier_activation="softmax",)

    func = K.function(full_model.layers[1].input, full_model.layers[155].output)
    conv_output = func([processed_image])
    y_pred = np.single(conv_output)

    y_true = np.zeros(1000).reshape(1,1000)
    y_true[0][282] = 1

    scce = tf.keras.losses.SparseCategoricalCrossentropy()
    scce(y_true, y_pred).numpy()

processed_image is a 1x224x224x3 array created earlier.

I get the error ValueError: Shape mismatch: The shape of labels (received (1000,)) should equal the shape of logits except for the last dimension (received (1, 1000)).

I tried reshaping the arrays to match the dimensions mentioned in the error, but it doesn't seem to work. What shapes does it accept?

Since you are using the SparseCategoricalCrossentropy loss function, y_true should have the shape [batch_size], while y_pred should have the shape [batch_size, num_classes]. In addition, y_true should consist of integer values. See the documentation. In your specific example, you can try something like this:

    import keras.backend as K
    import tensorflow as tf
    import numpy as np

    full_model = tf.keras.applications.MobileNetV2(
        input_shape=(224,224,3),
        alpha=1.0,
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        pooling=None,
        classes=1000,
        classifier_activation="softmax",)

    batch_size = 1
    processed_image = tf.random.uniform(shape=[batch_size,224,224,3])
    func = K.function(full_model.layers[1].input,
                      full_model.layers[155].output)
    conv_output = func([processed_image])
    y_pred = np.single(conv_output)

    # Generates an integer between 0 and 999 (inclusive) representing a class index.
    y_true = np.random.randint(low=0, high=1000, size=batch_size)
    # e.g. [984]
    scce = tf.keras.losses.SparseCategoricalCrossentropy()
    scce(y_true, y_pred).numpy()
    # y_pred encodes a probability distribution here and the calculated loss is 10.69202

You can play around with batch_size to see how everything works. In the example above, I only used a batch_size of 1.
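As a quick sanity check of the shape contract with a larger batch, here is a minimal sketch; it stands random softmax probabilities in for the MobileNetV2 output (the batch_size and num_classes values are illustrative, not from the original post):

```python
import numpy as np
import tensorflow as tf

batch_size = 4
num_classes = 1000

# Stand-in for the model output: random logits turned into a
# probability distribution; shape must be [batch_size, num_classes].
y_pred = tf.nn.softmax(tf.random.uniform(shape=[batch_size, num_classes]))

# Integer class indices, one per sample; shape must be [batch_size].
y_true = np.random.randint(low=0, high=num_classes, size=batch_size)

scce = tf.keras.losses.SparseCategoricalCrossentropy()
loss = scce(y_true, y_pred).numpy()
print(loss)  # a single scalar: the cross-entropy averaged over the batch
```

Note that no one-hot encoding is needed anywhere: the "sparse" variant of the loss takes the class index directly, which is exactly why the (1, 1000) one-hot y_true in the question was rejected.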

