I have a 50x22 dataset, which includes 22 features. The goal is to classify a target rated on a scale from 1 to 5, which amounts to 5 classes. I used a random forest and got 98% accuracy, but the validation accuracy is 63%, which is not satisfactory. That's why I decided to build a deep model, and I created one with 3 layers. The loss looks satisfactory, around 6.7*10e-4, but the accuracy is stuck at zero. I think there is something wrong in my code. So, what's the problem?
def build_and_compile_model(norm):
    model = keras.Sequential([
        norm,
        layers.Dense(32, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[tf.keras.metrics.Accuracy()])
    return model

def plot_acc(history):
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label='val_accuracy')
    plt.ylim([0, 1])
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy [GSR]')
    plt.legend()
    plt.grid(True)

dnn_qoe_model = build_and_compile_model(feature_normalizer)
dnn_qoe_model.summary()
history = dnn_qoe_model.fit(
    train_features[:22], train_labels,
    validation_split=0.2,
    verbose=0, epochs=100)
plot_acc(history)
You are using loss='binary_crossentropy' and layers.Dense(1,activation='sigmoid'), which are meant for binary classification problems. Since you want to predict one of 5 classes, this is a multi-class problem.
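To illustrate why the accuracy can sit at zero with this setup, here is a small pure-Python sketch. It assumes the metric counts a prediction as correct only when it equals the label exactly, which is how tf.keras.metrics.Accuracy is documented to behave. The helper names, logits and probabilities are made up for illustration and are not taken from your data.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def exact_match_accuracy(y_true, y_pred):
    """Fraction of predictions that are exactly equal to the labels."""
    hits = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return hits / len(y_true)

def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical labels on the 1-to-5 scale from the question.
labels = [1, 3, 5, 2, 4]

# Hypothetical raw outputs of a single-unit output layer.
logits = [-1.2, 0.3, 2.5, -0.4, 1.1]
sigmoid_outputs = [sigmoid(z) for z in logits]

# Each sigmoid output here lies strictly between 0 and 1,
# so none of them can equal an integer label.
exact_acc = exact_match_accuracy(labels, sigmoid_outputs)
print("exact-match accuracy on sigmoid outputs:", exact_acc)

# With five outputs per sample (hypothetical probabilities),
# the predicted class is the position of the largest value.
probs = [
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.10, 0.10, 0.60, 0.10, 0.10],
    [0.05, 0.05, 0.10, 0.10, 0.70],
    [0.20, 0.50, 0.10, 0.10, 0.10],
    [0.10, 0.10, 0.50, 0.20, 0.10],
]
predicted = [argmax(row) + 1 for row in probs]  # +1 maps index 0..4 to label 1..5
class_acc = exact_match_accuracy(labels, predicted)
print("accuracy on argmax classes:", class_acc)
```

In Keras, passing metrics=['accuracy'] instead of a tf.keras.metrics.Accuracy() instance lets the framework choose an accuracy variant that matches the loss, which is likely what you want here.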
If your target is one-hot encoded, so that one class looks like [0,1,0,0,0], you should use layers.Dense(5,activation='softmax') and loss='categorical_crossentropy'.
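As a rough illustration of the one-hot case, the sketch below encodes a label on the 1-to-5 scale and computes the categorical cross-entropy by hand. The functions one_hot, softmax and categorical_crossentropy are hypothetical helpers written for this example, not Keras APIs, and the logits are invented.

```python
import math

NUM_CLASSES = 5

def one_hot(label, num_classes=NUM_CLASSES):
    """Encode a label on the 1..num_classes scale as a one-hot list."""
    vec = [0.0] * num_classes
    vec[label - 1] = 1.0  # label 1 -> index 0, label 5 -> index 4
    return vec

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(target, probs):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

target = one_hot(2)  # [0.0, 1.0, 0.0, 0.0, 0.0]
probs = softmax([0.5, 2.0, 0.1, -1.0, 0.3])  # hypothetical output-layer scores
loss = categorical_crossentropy(target, probs)

print("target:", target)
print("loss:", loss)
```

Because the target has a single 1, the loss reduces to the negative log of the probability assigned to the true class.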
If your target isn't one-hot encoded, meaning each label is an integer giving the class index (which would be [1], the position of the positive class, in the previous example), you should use layers.Dense(5,activation='softmax') and change the loss function to loss='sparse_categorical_crossentropy'. The loss is called sparse because the target stores only the index of the entry that would be 1 in a vector of zeros. Note that this loss expects indices from 0 to 4, so labels on a 1-to-5 scale need to be shifted down by one.
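Putting this together, a minimal sketch of the adjusted model might look like the following. It assumes TensorFlow 2.x with layers.Normalization available, uses random stand-in data in place of your real train_features and train_labels, and trains for only a few epochs to keep the example short. It also passes all rows of train_features, since [:22] selects the first 22 rows rather than the 22 feature columns.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

NUM_FEATURES = 22
NUM_CLASSES = 5

def build_and_compile_model(norm):
    model = keras.Sequential([
        norm,
        layers.Dense(32, activation='relu'),
        layers.Dense(32, activation='relu'),
        # One output per class; softmax turns them into probabilities.
        layers.Dense(NUM_CLASSES, activation='softmax'),
    ])
    model.compile(optimizer='sgd',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Random stand-in data (an assumption, not the data from the question).
rng = np.random.default_rng(0)
train_features = rng.normal(size=(50, NUM_FEATURES)).astype('float32')
train_labels = rng.integers(1, 6, size=50)  # integer labels 1..5

# The sparse loss expects class indices 0..4, so shift the labels.
labels_zero_based = train_labels - 1

feature_normalizer = layers.Normalization(axis=-1)
feature_normalizer.adapt(train_features)

dnn_qoe_model = build_and_compile_model(feature_normalizer)
history = dnn_qoe_model.fit(
    train_features, labels_zero_based,
    validation_split=0.2,
    verbose=0, epochs=5)

# Map predicted class indices back to the 1..5 scale.
probs = dnn_qoe_model.predict(train_features, verbose=0)
predicted_labels = probs.argmax(axis=1) + 1
```

If your labels are already one-hot encoded, keep the same output layer and swap the loss for 'categorical_crossentropy'. With only 50 samples, the validation accuracy will be noisy either way.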