[英]How to show Precession, Recall and F1-Score?
I am currently in the process of displaying precision, recall and fscore.我目前正在显示精度、召回率和 fscore。 Now my question is how do I do this?现在我的问题是我该怎么做? What I tried is the following:我尝试的是以下内容:
num_users, num_items = train_mat.shape
user_input, item_input, labels = get_train_samples(train_mat, num_negatives)
val_user_input, val_item_input, val_labels = get_train_samples(val_mat, num_negatives)
.
.
.
history = model.fit([np.array(user_input), np.array(item_input)], np.array(labels),
epochs=EPOCHS, verbose=VERBOSE, shuffle=True, batch_size = BATCH_SIZE,
validation_data=([np.array(val_user_input), np.array(val_item_input)], np.array(val_labels)),
callbacks=CALLBACKS)
.
.
.
# Precision, recall and fscore
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, roc_curve, auc
precision, recall, fscore, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')
print('Precision, recall, and F1 score, averaged and weighted by number of instances in each class:')
print('precision: {}'.format(precision))
print('recall: {}'.format(recall))
print('f1 score: {}\n'.format(fscore))
precision, recall, fscore, _ = precision_recall_fscore_support(y_test, y_pred)
print('Precision, recall, and F1 score, per class [0 1]:')
print('precision: {}'.format(precision))
print('recall: {}'.format(recall))
print('f1 score: {}'.format(fscore))
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True)
Unfortunately I don't know how to get y_test
and y_pred
.不幸的是,我不知道如何获得y_test
和y_pred
。 How do I get these values?我如何获得这些值?
you shall have y_test
as the test set to test your model and if you dont have such a set you can use sklearn train test split for getting a training set and a test set.你应该有y_test
作为测试集来测试你的模型,如果你没有这样的集,你可以使用 sklearn train test split 来获取训练集和测试集。 Here is the link for how to use it: sklearn traiin test split这是如何使用它的链接: sklearn train test split
and when you will have your test set you will do this to get y_pred
:当您拥有测试集时,您将这样做以获得y_pred
:
y_pred = model.predict(y_test)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.