
How to apply the model.fit() function to a CNN-LSTM model?

I am trying to use this model to classify images into two categories, but when I call model.fit() it raises this error:

ValueError: A target array with shape (90, 1) was passed for an output of shape (None, 10) while using as loss binary_crossentropy. This loss expects targets to have the same shape as the output.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D, LSTM
import pickle
import numpy as np

X = np.array(pickle.load(open("X.pickle","rb")))
Y = np.array(pickle.load(open("Y.pickle","rb")))

#scaling our image data
X = X/255.0
model = Sequential()

model.add(Conv2D(64 ,(3,3), input_shape = (300,300,1)))

# model.add(MaxPooling2D(pool_size = (2,2)))

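# Conv2D with the default 'valid' padding outputs (298, 298, 64); treat each
# of the 298 rows as one LSTM time step with 298 * 64 features.
model.add(tf.keras.layers.Reshape((298, 298 * 64)))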
model.add(LSTM(128, activation='relu', return_sequences=True))
model.add(Dropout(0.2))

model.add(LSTM(128, activation='relu'))
model.add(Dropout(0.2))

model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))

model.add(Dense(10, activation='softmax'))

opt = tf.keras.optimizers.Adam(learning_rate=1e-3, decay=1e-5)


model.compile(loss='binary_crossentropy', optimizer=opt,
             metrics=['accuracy'])

# model.summary()
model.fit(X, Y, batch_size=32, epochs = 2, validation_split=0.1)
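Separately from the loss, the Reshape layer in this code only accepts a target shape with exactly as many elements as its input. A small check in plain Python (the helper name conv2d_valid_output is hypothetical) computes the Conv2D output size for the default padding='valid', where each spatial axis shrinks to (size - kernel) // stride + 1:

```python
def conv2d_valid_output(size, kernel, stride=1):
    """Output length of one spatial axis for a Keras Conv2D with padding='valid'."""
    return (size - kernel) // stride + 1

h = w = conv2d_valid_output(300, 3)  # 298
filters = 64
conv_elements = h * w * filters      # elements per sample after Conv2D(64, (3, 3))

# Reshape only succeeds when the element counts match.
print(conv_elements)                 # 5683456
print(16 * 16 * 512)                 # 131072, so (16, 16*512) cannot work
print(h * (w * filters))             # 5683456, so (298, 298*64) does
```

So a target like (16, 16*512) cannot reshape this tensor, while (298, 298*64), which treats each image row as one LSTM time step, can. Uncommenting MaxPooling2D changes the numbers again (149 x 149 x 64), so recompute after any change to the convolutional stack.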

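If your problem is categorical (multi-class), your issue is that you are using binary_crossentropy instead of categorical_crossentropy. Your last layer, Dense(10, activation='softmax'), describes a 10-class problem, which is why Keras reports an output of shape (None, 10); categorical_crossentropy in turn expects one-hot labels of shape (N, 10). First make sure that you really have a categorical rather than a binary classification problem.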
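To see why the target and output shapes must agree, here is a simplified NumPy sketch of categorical cross-entropy (not Keras's actual implementation; the helper names are hypothetical). Integer labels of shape (N, 1), like the Y in the question, are one-hot encoded to (N, 10) so they line up with a (N, 10) softmax output. In Keras the encoding step is usually done with tf.keras.utils.to_categorical.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Turn integer labels of shape (N,) or (N, 1) into a (N, num_classes) one-hot matrix."""
    labels = np.asarray(labels, dtype=int).reshape(-1)
    out = np.zeros((labels.size, num_classes))
    out[np.arange(labels.size), labels] = 1.0
    return out

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Simplified categorical cross-entropy: mean over samples of -sum(y_true * log(y_pred))."""
    if y_true.shape != y_pred.shape:
        # Mirror the shape check behind the error in the question.
        raise ValueError(f"target shape {y_true.shape} != output shape {y_pred.shape}")
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

labels = np.array([[0], [3], [9], [3]])  # shape (4, 1), like Y in the question
y_pred = np.full((4, 10), 0.1)           # shape (4, 10), like a Dense(10, softmax) output

try:
    categorical_crossentropy(labels.astype(float), y_pred)
except ValueError as err:
    print(err)  # target shape (4, 1) != output shape (4, 10)

y_true = one_hot(labels, 10)             # shape (4, 10): now the shapes agree
loss = categorical_crossentropy(y_true, y_pred)
print(y_true.shape, round(loss, 4))      # (4, 10) 2.3026, i.e. log(10) for a uniform guess
```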

Also note that if your labels are plain integer class indices such as [0, 1, 2, 3, ...] rather than one-hot encoded vectors, your loss function should be sparse_categorical_crossentropy, not categorical_crossentropy.
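A minimal NumPy sketch (simplified, hypothetical helper names, not the Keras source) shows that the sparse loss on integer labels gives the same value as the categorical loss on the same labels one-hot encoded; only the label format differs. Note that the integer labels must be class indices in the range [0, num_classes).

```python
import numpy as np

def sparse_categorical_crossentropy(labels, y_pred, eps=1e-7):
    """Simplified sparse CE: labels are class indices in [0, num_classes), shape (N,) or (N, 1)."""
    labels = np.asarray(labels, dtype=int).reshape(-1)
    num_classes = y_pred.shape[1]
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("labels must be integers in [0, num_classes)")
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Pick the predicted probability of the true class for each sample.
    picked = y_pred[np.arange(labels.size), labels]
    return float(np.mean(-np.log(picked)))

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Simplified categorical CE on one-hot targets with the same shape as y_pred."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])
labels = np.array([0, 1, 2])   # integer labels
y_true = np.eye(3)[labels]     # the same labels, one-hot encoded

sparse = sparse_categorical_crossentropy(labels, y_pred)
dense = categorical_crossentropy(y_true, y_pred)
print(np.isclose(sparse, dense))  # True: same loss, different label format
```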

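If you do have a binary classification problem, as your description of two categories suggests, use one of these two consistent combinations:

  1. Loss is binary_crossentropy + Dense(1, activation='sigmoid'); 0/1 labels of shape (N, 1), like your Y, work as-is.
  2. Loss is categorical_crossentropy + Dense(2, activation='softmax'); the labels must then be one-hot encoded to shape (N, 2) (or use sparse_categorical_crossentropy with the integer labels).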
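Both combinations describe the same kind of model. A NumPy sketch, assuming a two-unit softmax whose first logit is fixed at 0 (the function names are illustrative, not Keras APIs), shows that they give the same loss, because softmax([0, z])[1] = sigmoid(z). Option 1 consumes (N, 1) labels directly, while option 2 needs them one-hot encoded to (N, 2).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(logits):
    # Subtract the row max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def binary_crossentropy(y_true, p, eps=1e-7):
    p = np.clip(p, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y = np.array([[0.0], [1.0], [1.0], [0.0]])   # shape (N, 1), like Y in the question
z = np.array([[-1.5], [2.0], [0.3], [0.8]])  # one logit per sample

# Option 1: Dense(1, activation='sigmoid') + binary_crossentropy, labels used as-is.
loss_1 = binary_crossentropy(y, sigmoid(z))

# Option 2: Dense(2, activation='softmax') + categorical_crossentropy, one-hot labels (N, 2).
logits_2 = np.hstack([np.zeros_like(z), z])  # logits (0, z) for classes 0 and 1
y_onehot = np.hstack([1.0 - y, y])
loss_2 = categorical_crossentropy(y_onehot, softmax(logits_2))

print(np.isclose(loss_1, loss_2))            # True
```

For labels shaped like the question's (90, 1) array, option 1 is the smaller change: only the last layer needs to become Dense(1, activation='sigmoid'). In a real Dense(2) layer both logits are learned, which adds parameters but not expressive power.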

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.
