How to plot a prediction of a model?
I have the output of my model, `y_pred`, and `x_test`. How can I easily see the results by plotting them with `plt`?
Here I am plotting the results in 2 rows:
```python
import numpy as np
import matplotlib.pyplot as plt

def visualiseResult(y_pred, x_test):
    # Show the first 10 test images in a 2 x 5 grid
    fig, ax = plt.subplots(2, 5)
    for i in range(10):
        # y_pred[0] and y_pred[1] are the two model outputs; each sample is a 48x48 map
        y_pred_base = y_pred[0][i].reshape(48, 48)
        y_pred_aux = y_pred[1][i].reshape(48, 48)  # auxiliary output, not drawn below
        # Index of the largest value, converted from a flat index to (row, col)
        peak_row, peak_col = np.unravel_index(np.argmax(y_pred_base), y_pred_base.shape)
        row, col = divmod(i, 5)
        ax[row][col].imshow(x_test[i], cmap='gray')
        # scatter expects (x, y) = (column, row) in image coordinates
        ax[row][col].scatter(peak_col, peak_row)
    plt.show()
```
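Converting the flat argmax index back to pixel coordinates is the fragile step in this loop: for a row-major 48x48 array the row is `index // 48` and the column is `index % 48`, and `scatter` expects `(x, y)`, i.e. `(column, row)`. A minimal sketch of just that step, assuming each prediction is a single 2-D heatmap whose largest value marks the point of interest (the helper name `heatmap_peak` is hypothetical):

```python
import numpy as np

def heatmap_peak(heatmap):
    """Return (x, y) plotting coordinates of the largest value in a 2-D heatmap.

    Hypothetical helper: assumes the peak of the map marks the predicted point.
    """
    heatmap = np.asarray(heatmap)
    # argmax works on the flattened array; unravel_index maps it back to (row, col)
    row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    # matplotlib's scatter on top of imshow uses x = column, y = row
    return int(col), int(row)

# Usage: a 48x48 map whose peak is at row 10, column 30
demo = np.zeros((48, 48))
demo[10, 30] = 1.0
print(heatmap_peak(demo))
```

With that map the flat index is 10 * 48 + 30 = 510; dividing by 47 and rounding, as `round(x1/47)` does, gives 11 instead of the correct row 10, which is why integer division by the row length (or `np.unravel_index`) is the safer choice.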
Your question is vague, but I will try to help:
Here is an example that plots the training history after fitting the model, so you can check for overfitting or underfitting:
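```python
import os
from datetime import datetime
from matplotlib import pyplot as plt

# Timestamp for the file name (no ':' because it is not allowed in Windows file names)
date = datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")

# model, X_train/y_train, X_val/y_val, checkpoint and earlystopping are defined elsewhere;
# the model must be compiled with metrics=['accuracy'] for the keys used below to exist
history = model.fit(X_train, y_train,
                    epochs=200,
                    batch_size=16,
                    validation_data=(X_val, y_val),
                    callbacks=[checkpoint, earlystopping])

# Visualize the training history to see whether you're overfitting.
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['TRAIN', 'VAL'], loc='lower right')
os.makedirs('saved_charts', exist_ok=True)
plt.savefig('saved_charts/training_history_' + date + '.png')
plt.show()
```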
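When `validation_data` is passed to `fit`, Keras also records `loss` and `val_loss` in `history.history`, and the loss curves are often the clearer overfitting signal (validation loss rising while training loss keeps falling). A small sketch that plots any metric next to its `val_` counterpart from that dictionary; the helper name `plot_metric` is hypothetical, and the numbers in the usage line are dummy values only there to exercise the function:

```python
from matplotlib import pyplot as plt

def plot_metric(history_dict, metric="loss"):
    """Plot a training metric and its 'val_' counterpart from a Keras-style history dict."""
    fig, ax = plt.subplots()
    ax.plot(history_dict[metric], label="TRAIN")
    ax.plot(history_dict["val_" + metric], label="VAL")
    ax.set_title("Model " + metric)
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.legend(loc="upper right")
    return fig, ax

# Usage with dummy values shaped like history.history (not real training results)
dummy = {"loss": [1.0, 0.6, 0.4, 0.3], "val_loss": [1.1, 0.7, 0.65, 0.8]}
fig, ax = plot_metric(dummy, "loss")
```

With a real run you would call `plot_metric(history.history, "loss")` and then `plt.show()`.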
Since you are running predictions on a test set, I am guessing you want to plot how accurate those predictions are. For a classification model, I recommend a confusion matrix:
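A confusion matrix counts how many test samples of true class `i` were predicted as class `j`: rows are true labels, columns are predicted labels, and correct predictions land on the diagonal. A minimal NumPy sketch of that counting, assuming (as in the code below) that `y_test` is one-hot encoded and `y_pred` holds per-class scores; `confusion_counts` is a hypothetical name and this is not sklearn's implementation:

```python
import numpy as np

def confusion_counts(y_true_onehot, y_scores):
    """Count (true class, predicted class) pairs from one-hot labels and class scores."""
    true_idx = np.argmax(y_true_onehot, axis=1)
    pred_idx = np.argmax(y_scores, axis=1)
    n_classes = y_true_onehot.shape[1]
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at increments correctly even when the same (row, col) pair repeats
    np.add.at(cm, (true_idx, pred_idx), 1)
    return cm

y_true = np.eye(3)[[0, 0, 1, 2, 2]]          # true classes 0, 0, 1, 2, 2
y_scores = np.array([[0.9, 0.05, 0.05],      # predicted 0
                     [0.2, 0.7, 0.1],        # predicted 1
                     [0.1, 0.8, 0.1],        # predicted 1
                     [0.1, 0.1, 0.8],        # predicted 2
                     [0.6, 0.2, 0.2]])       # predicted 0
print(confusion_counts(y_true, y_scores))
```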
```python
import itertools
import os
from datetime import datetime

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Plot a confusion matrix and save it under saved_charts/."""
    date = datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")
    if normalize:
        # Divide each row by its total so every row (true class) sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=55)
    plt.yticks(tick_marks, classes)

    # Write each value in its cell: white on dark cells, black on light ones
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    os.makedirs('saved_charts', exist_ok=True)
    plt.savefig('saved_charts/confusion_matrix_' + date + '.png')
    plt.show()


# model, X_test, one-hot encoded y_test and class_names are defined elsewhere
y_pred = model.predict(X_test)
y_pred_label = [class_names[i] for i in np.argmax(y_pred, axis=1)]
y_true_label = [class_names[i] for i in np.argmax(y_test, axis=1)]

# Plot the confusion matrix
cm = confusion_matrix(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1))
plot_confusion_matrix(cm, class_names, title='Confusion Matrix of Model')
```
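The matrix also gives the headline numbers directly: overall accuracy is the diagonal sum divided by the total count (the same value `accuracy_score` returns for those labels), and dividing each diagonal entry by its row sum gives per-class recall, the same row-wise division the `normalize` branch performs. A short sketch, assuming `cm` is a square count matrix with true labels on the rows; `summarize_confusion` is a hypothetical helper name:

```python
import numpy as np

def summarize_confusion(cm):
    """Return overall accuracy and per-class recall from a confusion matrix (rows = true labels)."""
    cm = np.asarray(cm, dtype=float)
    accuracy = np.trace(cm) / cm.sum()
    # Row sums are the number of samples of each true class; classes with no samples get 0
    row_totals = cm.sum(axis=1)
    recall = np.divide(np.diag(cm), row_totals,
                       out=np.zeros_like(row_totals), where=row_totals > 0)
    return accuracy, recall

acc, recall = summarize_confusion([[1, 1, 0], [0, 1, 0], [1, 0, 1]])
print(acc, recall)
```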