
Keras CNN: high training accuracy but low testing accuracy

I am doing text classification; my dataset is 16,000 KB. My problem is that I get 95% accuracy in training but only 90% in testing. Can I increase the testing accuracy, and how?

Here is my code:

from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
model.add(layers.Conv1D(filters=256, kernel_size=5, activation='relu',
                        input_shape=(7, 1)))
model.add(layers.GlobalMaxPooling1D())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(11, activation='softmax'))
model.summary()

model.compile(keras.optimizers.Adam(learning_rate=0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
history = model.fit(X_train, y_train,
                    epochs=200,
                    verbose=True,
                    validation_data=(X_test, y_test),
                    batch_size=128)
loss, accuracy = model.evaluate(X_train, y_train, verbose=False)
print("Training Accuracy: {:.4f}".format(accuracy))
loss, accuracy = model.evaluate(X_test, y_test, verbose=False)
print("Testing Accuracy:  {:.4f}".format(accuracy))

The first step in debugging the model is to plot the training/validation curves, as in the example below.

[Figure: typical training/validation accuracy curves]
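Such a curve can be produced from the `History` object that `model.fit` returns; a minimal sketch, assuming matplotlib is installed (the dummy accuracy values at the bottom stand in for a real `history.history` dict):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt

def plot_history(history_dict, out_file="curves.png"):
    """Plot training vs. validation accuracy from a Keras History.history dict."""
    plt.figure()
    plt.plot(history_dict["accuracy"], label="training accuracy")
    plt.plot(history_dict["val_accuracy"], label="validation accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend()
    plt.savefig(out_file)
    return out_file

# In the real script you would call plot_history(history.history);
# dummy values are used here so the snippet runs stand-alone.
plot_history({"accuracy": [0.60, 0.80, 0.95],
              "val_accuracy": [0.55, 0.70, 0.90]})
```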

Depending on how the curves behave, there are several possible inferences and fixes:

  1. The two curves diverge as training proceeds: training accuracy keeps improving while validation accuracy gets worse or saturates much earlier.

    Cause: the model is overfitting the training set and needs regularisation, e.g. dropout, weight decay, etc.

  2. The two curves stay close together to the end and no further improvement happens.

    Cause: the model has saturated or is stuck in a local minimum. Try increasing the learning rate to push it out of the minimum; if there is still no major improvement, add more capacity to the model.

  3. Both curves have saturated, remain a small distance apart, and show no major change with further training.

    Cause: the model has learned what it can from the available data and will not improve further. Try data transformations to generate new samples, or collect more data.
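For case 1, regularisation can be added directly to the question's architecture. A sketch with dropout and L2 weight decay (the rate 0.5 and the coefficient 1e-4 are illustrative starting points, not tuned settings, and one Dense block is dropped to reduce capacity):

```python
from tensorflow import keras
from tensorflow.keras import layers, regularizers

model = keras.Sequential([
    layers.Conv1D(256, kernel_size=5, activation="relu", input_shape=(7, 1)),
    layers.GlobalMaxPooling1D(),
    layers.Dropout(0.5),  # randomly zero half the activations during training
    layers.Dense(128, activation="relu",
                 kernel_regularizer=regularizers.l2(1e-4)),  # weight decay
    layers.Dense(64, activation="relu",
                 kernel_regularizer=regularizers.l2(1e-4)),
    layers.Dense(11, activation="softmax"),
])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```

Dropout is only active during training, so `model.evaluate` on the test set still uses the full network.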

