繁体   English   中英

PyTorch CNN教程的混淆矩阵和测试准确率

[英]Confusion matrix and test accuracy for PyTorch CNN tutorial

我有兴趣只报告训练和测试的准确性以及混淆矩阵(比如使用 sklearn 混淆矩阵)。 我怎样才能做到这一点? 当前教程仅报告训练/验证准确性,我很难弄清楚如何在其中合并 sklearn 混淆矩阵代码。 此处链接到原始教程: https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/4%20-%20Convolutional%20Sentiment%20Analysis.ipynb

与教程中定义的binary_accuracy function 非常相似,您可以实现任何您想要的指标。 您只需要一组preds预测(在本例中为 preds )和真实目标( y )。

例如,对于混淆矩阵,您可以执行以下操作:

from sklearn.metrics import confusion_matrix

def compute_confusion_matrix(preds, y):
    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    return confusion_matrix(y, rounded_preds)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM