简体   繁体   English

从keras生成器获取真实标签

[英]get true labels from keras generator

I want to use sklearn.metrics.confusion_matrix(y_true, y_pred) to create a confusion matrix for a keras model. 我想使用sklearn.metrics.confusion_matrix(y_true, y_pred)为keras模型创建混淆矩阵。

After training a model I can use predict_generator(generator) to get predictions for a test dataset, which gives me y_pred . 训练模型后,我可以使用predict_generator(generator)来获得测试数据集的预测,这给了我y_pred How can I get the corresponding true labels, y_true from a data generator? 如何从数据生成器获取相应的true标签y_true

generator.classes will give you observed values in sparse format. generator.classes将为您提供稀疏格式的观察值。 You probably need it in dense (ie, one-hot encoded format). 您可能需要密集(即,一热编码格式)的文件。 You could get that with: 您可以通过以下方式获得:

import pandas as pd
pd.get_dummies(pd.Series(generator.classes)).to_dense()

NOTE though: you must set the generator's shuffle attribute to False before generating the predictions and fetching the observed classes, otherwise your predictions and observations will not line up! 但是请注意:在生成预测并获取观察到的类之前,必须将生成器的shuffle属性设置为False ,否则您的预测和观察将无法对齐!

After creating a data generator, either your own or the built in ImageDataGenerator , use your trained model to make predictions: 创建数据生成器(您自己的数据生成器或内置的ImageDataGenerator ,请使用受过训练的模型进行预测:

true_labels = data_generator.classes
predictions = model.predict_generator(data_generator)

sklearn's confusion matrix expects a 1-d array of labels, so you have to convert your predictions using np.argmax() sklearn的混淆矩阵需要一维标签数组,因此您必须使用np.argmax()转换预测

y_true = true_labels
y_pred = np.array([np.argmax(x) for x in predictions])

Then you can use those variables directly in the confusion_matrix function 然后,您可以直接在confusion_matrix函数中使用这些变量

cm = sklearn.metrics.confusion_matrix(y_true, y_pred)

And you can plot it using the example plot_confusion_matrix() function found here: 您可以使用此处的示例plot_confusion_matrix()函数对其进行绘制:

https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

在此处输入图片说明

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM