For a text dataset of ~20,000 samples, the true and false samples are split roughly 5,000 to 15,000. A two-channel textCNN built with Keras and Theano does the classification, and F1 score is the evaluation metric. The F1 score is not bad, but the confusion matrix shows that the accuracy on the true samples is relatively low (~40%). Predicting the true samples accurately is what matters most here. I therefore want to design a custom binary cross-entropy loss function that increases the weight of misclassified true samples, so the model focuses more on getting the true samples right.
Sample code for a cost-sensitive loss function that focuses on the misclassified samples:
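from itertools import product
from keras import backend as K

def w_categorical_crossentropy(y_true, y_pred, weights):
    # weights[c_t, c_p]: cost of predicting class c_p when the true class is c_t
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    # Mark, for every sample, the class with the highest predicted probability
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    # Older Keras releases take (output, target) here; newer ones take
    # (target, output), so check the argument order for your version.
    return K.categorical_crossentropy(y_pred, y_true) * final_mask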
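To make the role of final_mask concrete, here is a small NumPy sketch of what it computes per sample: the entry of the cost matrix at (true class, predicted class). The function name mask_from_weights, the cost matrix, and the example arrays are all hypothetical, chosen only for illustration. It uses argmax, so unlike the K.equal version above it picks a single class when predicted probabilities tie.

```python
import numpy as np

def mask_from_weights(y_true, y_pred, weights):
    """Per-sample weight: weights[true class, predicted class]."""
    true_cls = y_true.argmax(axis=1)   # index of the one-hot true label
    pred_cls = y_pred.argmax(axis=1)   # class with the highest predicted probability
    return weights[true_cls, pred_cls]

# Hypothetical cost matrix: rows = true class, columns = predicted class.
# A true sample (class 1) predicted as class 0 costs three times as much.
weights = np.array([[1.0, 1.0],
                    [3.0, 1.0]])

y_true = np.array([[0, 1], [0, 1], [1, 0]])                # one-hot labels
y_pred = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])    # predicted probabilities

mask = mask_from_weights(y_true, y_pred, weights)
print(mask)  # [3. 1. 1.]
```

Only the first sample is a true sample predicted as false, so only its cross-entropy term is scaled up.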
A custom loss function for binary classification, implemented with Keras and Theano, that focuses on the misclassified samples would be very valuable for this imbalanced dataset. Please help me work this out. Thanks!
When I have to deal with imbalanced datasets in Keras, I first compute a weight for each class and pass those weights to the model during training. It looks something like this:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# targets is expected to be a 1-D array of class labels
classes = np.unique(targets)
w = compute_class_weight(class_weight='balanced', classes=classes, y=targets)
# here I am adding only two categories with their corresponding weights
# you can spin a loop or continue by hand until you include all of your categories
weights = {
    classes[0]: w[0],  # weight for class 0
    classes[1]: w[1],  # weight for class 1
}
# then during training you do like this
model.fit(x=features, y=targets, class_weight=weights)  # plus your other fit arguments
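For reference, the 'balanced' setting amounts to n_samples / (n_classes * count_per_class). The sketch below reproduces that arithmetic with plain NumPy; the label array is synthetic, built only to match the rough 15,000 / 5,000 split described in the question.

```python
import numpy as np

# Synthetic labels matching the approximate counts in the question:
# 15,000 false (0) and 5,000 true (1) samples.
targets = np.array([0] * 15000 + [1] * 5000)

classes, counts = np.unique(targets, return_counts=True)
# 'balanced' heuristic: n_samples / (n_classes * count_per_class)
balanced = len(targets) / (len(classes) * counts)
weights = {int(c): float(w) for c, w in zip(classes, balanced)}
print(weights)
```

With this split the true class ends up weighted about three times as heavily as the false class (2.0 against roughly 0.67).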
I believe this will solve your problem.
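If you still want a custom loss, one common variant is a binary cross-entropy whose positive-class term is multiplied by a constant, so that a missed true sample costs more than a missed false one. Below is a minimal NumPy sketch of that arithmetic, not a drop-in Keras loss. The name weighted_binary_crossentropy, the pos_weight value, and the example numbers are assumptions for illustration. The same expression should translate to backend ops such as K.clip, K.log and K.mean inside a function of (y_true, y_pred).

```python
import numpy as np

def weighted_binary_crossentropy(y_true, y_pred, pos_weight=3.0, eps=1e-7):
    """Binary cross-entropy with the positive-class term scaled by pos_weight."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions away from 0 and 1 so the logarithms stay finite
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    per_sample = -(pos_weight * y_true * np.log(y_pred)
                   + (1.0 - y_true) * np.log(1.0 - y_pred))
    return per_sample.mean()

# Hypothetical labels and predicted probabilities for the true class
y_true = [1, 1, 0, 0]
y_pred = [0.2, 0.9, 0.1, 0.6]

plain = weighted_binary_crossentropy(y_true, y_pred, pos_weight=1.0)
weighted = weighted_binary_crossentropy(y_true, y_pred, pos_weight=3.0)
print(plain, weighted)
```

With pos_weight=1.0 this reduces to ordinary binary cross-entropy. Larger values push the optimizer toward fewer missed true samples, usually at the cost of more false positives, so the value is worth tuning against the confusion matrix.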