繁体   English   中英

R 如何使用 caret 包可视化混淆矩阵

[英]R how to visualize confusion matrix using the caret package

我想可视化我放入混淆矩阵的数据。 有没有一个函数我可以简单地放置混淆矩阵,它可以将它可视化(绘制它)?

例如我想做的事情(Matrix$nnet 只是一个包含分类结果的表格):

Confusion$nnet <- confusionMatrix(Matrix$nnet)
plot(Confusion$nnet)

我的 Confusion$nnet$table 看起来像这样:

    prediction (I would also like to get rid of this string, any help?)
    1  2
1   42 6
2   8 28

您可以仅使用 r 中的 rect 功能来布局混淆矩阵。 这里我们将创建一个函数,该函数允许用户传入由 caret 包创建的 cm 对象以产生视觉效果。

让我们首先创建一个评估数据集,如插入符号演示中所做的那样

# construct the evaluation dataset
set.seed(144)
true_class <- factor(sample(paste0("Class", 1:2), size = 1000, prob = c(.2, .8), replace = TRUE))
true_class <- sort(true_class)
class1_probs <- rbeta(sum(true_class == "Class1"), 4, 1)
class2_probs <- rbeta(sum(true_class == "Class2"), 1, 2.5)
test_set <- data.frame(obs = true_class,Class1 = c(class1_probs, class2_probs))
test_set$Class2 <- 1 - test_set$Class1
test_set$pred <- factor(ifelse(test_set$Class1 >= .5, "Class1", "Class2"))

现在让我们使用插入符号来计算混淆矩阵

# calculate the confusion matrix
cm <- confusionMatrix(data = test_set$pred, reference = test_set$obs)

现在我们创建一个函数,根据需要布置矩形,以更具视觉吸引力的方式展示混淆矩阵

draw_confusion_matrix <- function(cm) {

  layout(matrix(c(1,1,2)))
  par(mar=c(2,2,2,2))
  plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
  title('CONFUSION MATRIX', cex.main=2)

  # create the matrix 
  rect(150, 430, 240, 370, col='#3F97D0')
  text(195, 435, 'Class1', cex=1.2)
  rect(250, 430, 340, 370, col='#F7AD50')
  text(295, 435, 'Class2', cex=1.2)
  text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
  text(245, 450, 'Actual', cex=1.3, font=2)
  rect(150, 305, 240, 365, col='#F7AD50')
  rect(250, 305, 340, 365, col='#3F97D0')
  text(140, 400, 'Class1', cex=1.2, srt=90)
  text(140, 335, 'Class2', cex=1.2, srt=90)

  # add in the cm results 
  res <- as.numeric(cm$table)
  text(195, 400, res[1], cex=1.6, font=2, col='white')
  text(195, 335, res[2], cex=1.6, font=2, col='white')
  text(295, 400, res[3], cex=1.6, font=2, col='white')
  text(295, 335, res[4], cex=1.6, font=2, col='white')

  # add in the specifics 
  plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
  text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
  text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
  text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
  text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
  text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
  text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
  text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
  text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
  text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
  text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

  # add in the accuracy information 
  text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
  text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
  text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
  text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
}  

最后,传入我们在使用插入符号创建混淆矩阵时计算的cm 对象

draw_confusion_matrix(cm)

结果如下:

来自 caret 包的混淆矩阵的可视化

您可以使用内置的fourfoldplot 例如,

ctable <- as.table(matrix(c(42, 6, 8, 28), nrow = 2, byrow = TRUE))
fourfoldplot(ctable, color = c("#CC6666", "#99CC99"),
             conf.level = 0, margin = 1, main = "Confusion Matrix")

在此处输入图片说明

您可以使用函数conf_mat()yardstick加上autoplot()在几行得到一个相当不错的结果。

另外,您仍然可以使用基本的ggplot来修复样式。

library(yardstick)
library(ggplot2)


# The confusion matrix from a single assessment set (i.e. fold)
cm <- conf_mat(truth_predicted, obs, pred)

autoplot(cm, type = "heatmap") +
  scale_fill_gradient(low="#D6EAF8",high = "#2E86C1")

在此处输入图片说明


作为进一步自定义的示例,使用ggplot您还可以添加回图例:

+ theme(legend.position = "right")

更改图例的名称也很容易: + labs(fill="legend_name")

在此处输入图片说明

数据示例:

set.seed(123)
truth_predicted <- data.frame(
  obs = sample(0:1,100, replace = T),
  pred = sample(0:1,100, replace = T)
)
truth_predicted$obs <- as.factor(truth_predicted$obs)
truth_predicted$pred <- as.factor(truth_predicted$pred)

我真的很喜欢@Cyber​​netic 漂亮的混淆矩阵可视化,并做了两个调整以希望进一步改进它。

1) 我用类的实际值交换了 Class1 和 Class2。 2)我用基于百分位数生成红色(未命中)和绿色(命中)的函数替换橙色和蓝色。 我们的想法是快速查看问题/成功的位置及其大小。

截图和代码:

混淆矩阵代码更新

draw_confusion_matrix <- function(cm) {

  total <- sum(cm$table)
  res <- as.numeric(cm$table)

  # Generate color gradients. Palettes come from RColorBrewer.
  greenPalette <- c("#F7FCF5","#E5F5E0","#C7E9C0","#A1D99B","#74C476","#41AB5D","#238B45","#006D2C","#00441B")
  redPalette <- c("#FFF5F0","#FEE0D2","#FCBBA1","#FC9272","#FB6A4A","#EF3B2C","#CB181D","#A50F15","#67000D")
  getColor <- function (greenOrRed = "green", amount = 0) {
    if (amount == 0)
      return("#FFFFFF")
    palette <- greenPalette
    if (greenOrRed == "red")
      palette <- redPalette
    colorRampPalette(palette)(100)[10 + ceiling(90 * amount / total)]
  }

  # set the basic layout
  layout(matrix(c(1,1,2)))
  par(mar=c(2,2,2,2))
  plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
  title('CONFUSION MATRIX', cex.main=2)

  # create the matrix 
  classes = colnames(cm$table)
  rect(150, 430, 240, 370, col=getColor("green", res[1]))
  text(195, 435, classes[1], cex=1.2)
  rect(250, 430, 340, 370, col=getColor("red", res[3]))
  text(295, 435, classes[2], cex=1.2)
  text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
  text(245, 450, 'Actual', cex=1.3, font=2)
  rect(150, 305, 240, 365, col=getColor("red", res[2]))
  rect(250, 305, 340, 365, col=getColor("green", res[4]))
  text(140, 400, classes[1], cex=1.2, srt=90)
  text(140, 335, classes[2], cex=1.2, srt=90)

  # add in the cm results
  text(195, 400, res[1], cex=1.6, font=2, col='white')
  text(195, 335, res[2], cex=1.6, font=2, col='white')
  text(295, 400, res[3], cex=1.6, font=2, col='white')
  text(295, 335, res[4], cex=1.6, font=2, col='white')

  # add in the specifics 
  plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
  text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
  text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
  text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
  text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
  text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
  text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
  text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
  text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
  text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
  text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

  # add in the accuracy information 
  text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
  text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
  text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
  text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
}

这是一个简单的基于ggplot2的想法,可以根据需要进行更改,我正在使用此链接中的数据:

#data
confusionMatrix(iris$Species, sample(iris$Species))
newPrior <- c(.05, .8, .15)
names(newPrior) <- levels(iris$Species)

cm <-confusionMatrix(iris$Species, sample(iris$Species))

现在 cm 是一个混淆矩阵对象,可以取出一些对问题有用的东西:

# extract the confusion matrix values as data.frame
cm_d <- as.data.frame(cm$table)
# confusion matrix statistics as data.frame
cm_st <-data.frame(cm$overall)
# round the values
cm_st$cm.overall <- round(cm_st$cm.overall,2)

# here we also have the rounded percentage values
cm_p <- as.data.frame(prop.table(cm$table))
cm_d$Perc <- round(cm_p$Freq*100,2)

现在我们准备好绘制:

library(ggplot2)     # to plot
library(gridExtra)   # to put more
library(grid)        # plot together

# plotting the matrix
cm_d_p <-  ggplot(data = cm_d, aes(x = Prediction , y =  Reference, fill = Freq))+
  geom_tile() +
  geom_text(aes(label = paste("",Freq,",",Perc,"%")), color = 'red', size = 8) +
  theme_light() +
  guides(fill=FALSE) 

# plotting the stats
cm_st_p <-  tableGrob(cm_st)

# all together
grid.arrange(cm_d_p, cm_st_p,nrow = 1, ncol = 2, 
             top=textGrob("Confusion Matrix and Statistics",gp=gpar(fontsize=25,font=1)))

在此处输入图片说明

@cybernetics:令人惊叹的情节兄弟。 我通过这篇文章学到了很多东西。 可以使用此矩形和文本函数在模型摘要中表示结果时完成批次。 惊人的东西。

非常感谢。 祝你未来项目万事如意。

我知道这已经很晚了,但我一直在寻找自己的解决方案。 除了这篇文章之外,还处理了上面的一些以前的答案。 使用ggplot2包和基table函数,我制作了这个简单的函数来绘制一个漂亮的彩色混淆矩阵:

conf_matrix <- function(df.true, df.pred, title = "", true.lab ="True Class", pred.lab ="Predicted Class",
                        high.col = 'red', low.col = 'white') {
  #convert input vector to factors, and ensure they have the same levels
  df.true <- as.factor(df.true)
  df.pred <- factor(df.pred, levels = levels(df.true))
  
  #generate confusion matrix, and confusion matrix as a pecentage of each true class (to be used for color) 
  df.cm <- table(True = df.true, Pred = df.pred)
  df.cm.col <- df.cm / rowSums(df.cm)
  
  #convert confusion matrices to tables, and binding them together
  df.table <- reshape2::melt(df.cm)
  df.table.col <- reshape2::melt(df.cm.col)
  df.table <- left_join(df.table, df.table.col, by =c("True", "Pred"))
  
  #calculate accuracy and class accuracy
  acc.vector <- c(diag(df.cm)) / c(rowSums(df.cm))
  class.acc <- data.frame(Pred = "Class Acc.", True = names(acc.vector), value = acc.vector)
  acc <- sum(diag(df.cm)) / sum(df.cm)
  
  #plot
  ggplot() +
    geom_tile(aes(x=Pred, y=True, fill=value.y),
              data=df.table, size=0.2, color=grey(0.5)) +
    geom_tile(aes(x=Pred, y=True),
              data=df.table[df.table$True==df.table$Pred, ], size=1, color="black", fill = 'transparent') +
    scale_x_discrete(position = "top",  limits = c(levels(df.table$Pred), "Class Acc.")) +
    scale_y_discrete(limits = rev(unique(levels(df.table$Pred)))) +
    labs(x=pred.lab, y=true.lab, fill=NULL,
         title= paste0(title, "\nAccuracy ", round(100*acc, 1), "%")) +
    geom_text(aes(x=Pred, y=True, label=value.x),
              data=df.table, size=4, colour="black") +
    geom_text(data = class.acc, aes(Pred, True, label = paste0(round(100*value), "%"))) +
    scale_fill_gradient(low=low.col, high=high.col, labels = scales::percent,
                        limits = c(0,1), breaks = c(0,0.5,1)) +
    guides(size=F) +
    theme_bw() +
    theme(panel.border = element_blank(), legend.position = "bottom",
          axis.text = element_text(color='black'), axis.ticks = element_blank(),
          panel.grid = element_blank(), axis.text.x.top = element_text(angle = 30, vjust = 0, hjust = 0)) +
    coord_fixed()

} 

您只需复制并粘贴该函数,然后将其保存到您的全局环境中即可。

下面是一个例子:

mydata <- data.frame(true = c("a", "b", "c", "a", "b", "c", "a", "b", "c"),
                     predicted = c("a", "a", "c", "c", "a", "c", "a", "b", "c"))

conf_matrix(mydata$true, mydata$predicted, title = "Conf. Matrix Example")

在此处输入图片说明

cvmsplot_confusion_matrix()以及一些花里胡哨的:


# Create targets and predictions data frame
data <- data.frame(
  "target" = c("A", "B", "A", "B", "A", "B", "A", "B",
               "A", "B", "A", "B", "A", "B", "A", "A"),
  "prediction" = c("B", "B", "A", "A", "A", "B", "B", "B",
                   "B", "B", "A", "B", "A", "A", "A", "A"),
  stringsAsFactors = FALSE
)

# Evaluate predictions and create confusion matrix
eval <- evaluate(
  data = data,
  target_col = "target",
  prediction_cols = "prediction",
  type = "binomial"
)

eval

> # A tibble: 1 x 19
>   `Balanced Accuracy` Accuracy    F1 Sensitivity Specificity `Pos Pred Value` `Neg Pred Value`   AUC `Lower CI`
>                 <dbl>    <dbl> <dbl>       <dbl>       <dbl>            <dbl>            <dbl> <dbl>      <dbl>
> 1               0.690    0.688 0.667       0.714       0.667            0.625             0.75 0.690      0.447
> # … with 10 more variables: Upper CI <dbl>, Kappa <dbl>, MCC <dbl>, Detection Rate <dbl>,
> #   Detection Prevalence <dbl>, Prevalence <dbl>, Predictions <list>, ROC <named list>, Confusion Matrix <list>,
> #   Process <list>

# Plot confusion matrix
# Either supply confusion matrix tibble directly
plot_confusion_matrix(eval[["Confusion Matrix"]][[1]])

# Or plot first confusion matrix in evaluate() output
plot_confusion_matrix(eval)

混淆矩阵图

输出是一个 ggplot 对象。

最简单的方法,结合插入符号:

library(caret)
library(yardstick)
library(ggplot2)

训练模型

plsFit <- train(
  y ~ .,
  data = trainData
)

从模型中获取预测

plsClasses <- predict(plsFit, newdata = testdata)

truth_predicted<-data.frame(
  obs = testdata$y,
  pred = plsClasses
)

制作矩阵。 注意 obs 和 pred 不是字符串

cm <- conf_mat(truth_predicted, obs, pred)

阴谋

autoplot(cm, type = "heatmap") +
  scale_fill_gradient(low="#D6EAF8",high = "#2E86C1")

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM