Error in keras sparse_categorical_crossentropy loss function

I am trying to make predictions with a deep neural network, but I get this error:

InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [32,4] and labels shape [128]

Here are the features:

new_features.shape
(19973, 8)

new_features[0].shape
(8,)

Here are the labels/outputs:

output.shape
(19973, 4)

output[0].shape
(4,)

Here is the Keras code:

model = Sequential(
  [
    Dense(units=8, input_shape=new_features[0].shape, name="layer1"),
    Dense(units=1024, activation="relu", name="layer2"),
    Dense(units=1024, activation="relu", name="layer3"),
    Dense(units=4,  name="layer4", activation="softmax"),
  ]
)

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(new_features, output, epochs=2)

The features and labels contain float values.

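The problem is the shape of your target. sparse_categorical_crossentropy expects a 1D array with one integer class index per sample, but your output has shape (19973, 4). With a batch of 32 samples, the 32 x 4 label values are treated as 128 separate labels, which no longer matches the 32 rows of logits; that is consistent with "labels shape [128]" in the error. For a classification problem, a sparse target must also hold integer class indices, not arbitrary floats.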
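The shape mismatch can be reproduced with NumPy alone. This is a minimal sketch, assuming a batch size of 32 (the Keras default, and the first dimension of the logits in the error message); the array names are illustrative and not taken from the question.

```python
import numpy as np

batch_size, n_classes = 32, 4

# One batch of model outputs: one row of class scores per sample.
logits = np.zeros((batch_size, n_classes))

# One batch of 2D labels, shaped like a slice of the question's `output` array.
labels_2d = np.zeros((batch_size, n_classes))

# A sparse loss expects one class index per sample, so a 2D label
# batch ends up treated as a flat list of 32 * 4 = 128 values.
flat_labels = labels_2d.reshape(-1)

print(logits.shape)       # (32, 4)
print(flat_labels.shape)  # (128,)
```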

If you have a 1D integer-encoded target, you can use sparse_categorical_crossentropy as the loss function:

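import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# 1000 samples with 100 features each; labels are integer class ids in {0, 1, 2}
X = np.random.randint(0, 10, (1000, 100))
y = np.random.randint(0, 3, 1000)  # shape (1000,)

model = Sequential([
    Dense(128, input_dim=100),
    Dense(3, activation='softmax'),  # one output unit per class
])
model.summary()
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, y, epochs=3)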
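Before fitting with sparse_categorical_crossentropy, it can help to check that the target really has this form. The following is a small sketch; `check_sparse_labels` is a hypothetical helper written for illustration, not part of Keras.

```python
import numpy as np

def check_sparse_labels(y, n_classes):
    """Return y as an int array if it is a valid sparse target, else raise ValueError."""
    y = np.asarray(y)
    if y.ndim != 1:
        raise ValueError(f"expected 1D labels, got shape {y.shape}")
    # Floats are accepted only if they are whole numbers such as 2.0.
    if not np.all(y == np.round(y)):
        raise ValueError("labels must be whole numbers (class indices)")
    y = y.astype(int)
    if y.min() < 0 or y.max() >= n_classes:
        raise ValueError(f"labels must lie in [0, {n_classes - 1}]")
    return y
```

Calling it on an array shaped like the question's output, (19973, 4), would raise on the first check, which points at the same problem as the Keras error.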

Otherwise, if you have one-hot encoded your target so that it has a 2D shape (n_samples, n_classes), you can use categorical_crossentropy:

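import numpy as np
import pandas as pd
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

X = np.random.randint(0, 10, (1000, 100))
# One-hot encode integer class ids into shape (1000, 3); cast to float
# because recent pandas versions return boolean dummy columns
y = pd.get_dummies(np.random.randint(0, 3, 1000)).values.astype('float32')

model = Sequential([
    Dense(128, input_dim=100),
    Dense(3, activation='softmax'),  # one output unit per class
])
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, y, epochs=3)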
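For a target like the (19973, 4) array in the question, either loss can be made to work by converting between the two encodings. This is a sketch under the assumption that each row marks exactly one class (one-hot, or at least with a clear maximum); if the four columns are independent float targets, the task is closer to regression and a cross-entropy loss is probably not the right choice. The small array below is made-up data.

```python
import numpy as np

# Illustrative one-hot labels for 5 samples and 4 classes.
one_hot = np.array([
    [1., 0., 0., 0.],
    [0., 0., 1., 0.],
    [0., 1., 0., 0.],
    [0., 0., 0., 1.],
    [0., 0., 1., 0.],
])

# One-hot -> integer class indices, for sparse_categorical_crossentropy.
sparse = one_hot.argmax(axis=1)

# Integer class indices -> one-hot, for categorical_crossentropy.
n_classes = one_hot.shape[1]
restored = np.eye(n_classes)[sparse]

print(sparse)          # [0 2 1 3 2]
print(restored.shape)  # (5, 4)
```

Under that assumption, the question's model could be fitted either with output.argmax(axis=1) as the target and the sparse loss, or with output unchanged and loss='categorical_crossentropy'.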
