
Evaluating a Multi-Label Classification model

I currently have a multi-label classification problem, for which I am using Keras to build a neural network as follows:

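import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# `dataset` (feature matrix) and `labels` (multi-hot targets, 26 columns) are loaded beforehand.
n_cols = dataset.shape[1]
print(n_cols)

model = Sequential()
model.add(Dense(128, activation='relu', input_shape=(n_cols,)))
model.add(Dense(64, activation='relu'))
model.add(Dense(26, activation='sigmoid')) # Sigmoid for multi-label classification

# An SGD optimizer was also tried but is not used: compile() below uses RMSprop.
# Recent Keras versions take `learning_rate` instead of `lr` and no longer support `decay`.
# sgd = SGD(learning_rate=0.1, momentum=0.5, nesterov=True)
model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['accuracy'])

model.summary()

## Fit the model ##
# EarlyStopping monitors val_loss by default; stop after 20 epochs without improvement.
early_stopping_monitor = EarlyStopping(patience=20)
history = model.fit(dataset, labels, validation_split=0.33, epochs=30, callbacks=[early_stopping_monitor])

# Plot training vs. validation accuracy per epoch.
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()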

I was told that for multi-label classification one should use binary_crossentropy as the loss and a sigmoid activation in the final (output) layer. However, with this setup I get an accuracy and val_accuracy of only ~0.0931 and ~0.0937, respectively.
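The NumPy sketch below illustrates why sigmoid is paired with binary_crossentropy. It is not Keras's own implementation, and the logits and the four labels are made up. Sigmoid scores each label independently, so two labels that are both present can each approach 1. Softmax divides a single unit of probability across all labels. The loss averages the per-label binary cross-entropy terms, which is roughly what Keras's binary_crossentropy does over the last axis.

```python
import numpy as np

def sigmoid(z):
    """Element-wise logistic function: each label gets its own probability."""
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    """Row-wise softmax: probabilities across all labels sum to 1."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean of the per-label binary cross-entropy terms (clipped to avoid log(0))."""
    p = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

# One sample, four hypothetical labels; labels 0 and 2 are both present.
logits = np.array([[3.0, -2.0, 3.0, -2.0]])
y_true = np.array([[1.0, 0.0, 1.0, 0.0]])

p_sig = sigmoid(logits)
p_soft = softmax(logits)
print(p_sig.round(3))   # both present labels get a high probability
print(p_soft.round(3))  # the two present labels must share one unit of probability

loss_sig = binary_crossentropy(y_true, p_sig)
loss_soft = binary_crossentropy(y_true, p_soft)
print(loss_sig, loss_soft)  # the sigmoid outputs fit the multi-hot target better
```

With softmax, neither present label can exceed 0.5 here. Softmax suits mutually exclusive classes rather than independent labels.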

For multi-label classification, is the accuracy metric the right choice? I've looked around, and some suggest that other metrics such as binary_accuracy may be better.
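Computing the candidate metrics by hand on toy data shows why the choice matters. The NumPy sketch below compares three of them; the data and the 0.5 threshold are assumptions chosen for illustration:

- argmax-based categorical accuracy;
- Keras-style binary accuracy, computed per label entry after thresholding;
- strict subset (exact-match) accuracy.

Depending on the Keras/TensorFlow version, the string `'accuracy'` is resolved either from the loss or from the output and target shapes. In recent tf.keras versions, a 26-unit output with multi-hot targets may map to categorical accuracy, which would be consistent with the low values above. Naming the metric explicitly, for example `metrics=[tf.keras.metrics.BinaryAccuracy()]`, avoids the ambiguity.

```python
import numpy as np

def categorical_accuracy(y_true, y_prob):
    """Share of samples whose top-scoring label equals the argmax of the target row."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_prob, axis=1)))

def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Share of individual (sample, label) entries that are correct after thresholding."""
    y_pred = (y_prob > threshold).astype(int)
    return float(np.mean(y_pred == y_true))

def subset_accuracy(y_true, y_prob, threshold=0.5):
    """Share of samples whose whole label vector is predicted exactly."""
    y_pred = (y_prob > threshold).astype(int)
    return float(np.mean(np.all(y_pred == y_true, axis=1)))

# Toy multi-hot targets (3 samples, 4 hypothetical labels) and sigmoid outputs.
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 0],
                   [1, 1, 0, 0]])
y_prob = np.array([[0.2, 0.1, 0.9, 0.3],
                   [0.1, 0.8, 0.2, 0.1],
                   [0.7, 0.4, 0.1, 0.2]])

cat_acc = categorical_accuracy(y_true, y_prob)  # 2/3: looks at one label per row
bin_acc = binary_accuracy(y_true, y_prob)       # 10/12: averages over every entry
sub_acc = subset_accuracy(y_true, y_prob)       # 1/3: the whole row must match
print(cat_acc, bin_acc, sub_acc)

# A model that predicts "no label" everywhere still gets credit for every true 0.
all_zero_acc = binary_accuracy(y_true, np.zeros_like(y_prob))  # 7/12
print(all_zero_acc)
```

When most label entries are 0, binary accuracy can look flattering, as the last line shows. This is one reason precision/recall-type metrics are often preferred for multi-label problems.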

So the question is, how can one best evaluate the multi-label classification?

EDIT: For reference, I have 26 label columns in my target "classes", and the dataset consists of 21 columns. The entire dataset the model is trained on has ~82k samples.

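As mentioned in this post, it is correct to use binary_crossentropy. The answer below in that thread suggests adding a softmax layer as the last layer as a first change to your model. Note, however, that softmax forces the output probabilities to sum to 1, which fits mutually exclusive classes. Independent labels are usually paired with a sigmoid output and binary_crossentropy, as in your model. As for your actual question: as mentioned in the linked post, you might want to evaluate with precision and recall instead of accuracy.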
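A minimal NumPy sketch of the precision/recall route follows. The function name `multilabel_report`, the toy data and the 0.5 threshold are all hypothetical. The two averaging modes differ in how labels are weighted:

- **Micro averaging** pools true/false positives over all labels, so frequent labels dominate.
- **Macro averaging** computes each label's score first and then takes the unweighted mean, so rare labels count as much as common ones.

Hamming loss (the fraction of wrong label entries, i.e. 1 minus binary accuracy) and subset accuracy are reported alongside. For real use, scikit-learn provides `precision_recall_fscore_support(y_true, y_pred, average='micro')` (or `'macro'`), `hamming_loss`, and `accuracy_score`, which gives subset accuracy for multi-label indicator input. During training, `tf.keras.metrics.Precision()` and `tf.keras.metrics.Recall()` can be passed to `compile()`; unless `class_id` is set, they pool all label entries, much like micro averaging.

```python
import numpy as np

def multilabel_report(y_true, y_prob, threshold=0.5):
    """Micro/macro precision, recall and F1, plus Hamming loss and subset accuracy."""
    y_true = np.asarray(y_true).astype(int)
    y_pred = (np.asarray(y_prob) > threshold).astype(int)

    # Per-label confusion counts: one entry per label column.
    tp = np.sum((y_pred == 1) & (y_true == 1), axis=0)
    fp = np.sum((y_pred == 1) & (y_true == 0), axis=0)
    fn = np.sum((y_pred == 0) & (y_true == 1), axis=0)

    def safe_div(num, den):
        # Treat 0/0 as 0 (e.g. a label that is never predicted and never present).
        num = np.asarray(num, dtype=float)
        den = np.asarray(den, dtype=float)
        return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    def f1(p, r):
        return safe_div(2 * p * r, p + r)

    # Micro: pool counts over all labels, then compute the ratios.
    micro_p = float(safe_div(tp.sum(), tp.sum() + fp.sum()))
    micro_r = float(safe_div(tp.sum(), tp.sum() + fn.sum()))

    # Macro: compute each label's scores, then take the unweighted mean.
    per_label_p = safe_div(tp, tp + fp)
    per_label_r = safe_div(tp, tp + fn)
    per_label_f1 = f1(per_label_p, per_label_r)

    return {
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_f1": float(f1(micro_p, micro_r)),
        "macro_precision": float(per_label_p.mean()),
        "macro_recall": float(per_label_r.mean()),
        "macro_f1": float(per_label_f1.mean()),
        "hamming_loss": float(np.mean(y_pred != y_true)),
        "subset_accuracy": float(np.mean(np.all(y_pred == y_true, axis=1))),
    }

# Toy data: 3 samples, 4 hypothetical labels, sigmoid outputs from some model.
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 0],
                   [1, 1, 0, 0]])
y_prob = np.array([[0.2, 0.1, 0.9, 0.3],
                   [0.1, 0.8, 0.2, 0.1],
                   [0.7, 0.4, 0.1, 0.2]])

report = multilabel_report(y_true, y_prob)
for name, value in report.items():
    print(f"{name}: {value:.3f}")
```

On real data, it can also help to inspect per-label precision and recall and to tune the threshold per label on the validation split instead of fixing it at 0.5.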

The technical posts on this site are published under the CC BY-SA 4.0 license. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM