簡體   English   中英

R 如何使用 caret 包可視化混淆矩陣

[英]R how to visualize confusion matrix using the caret package

我想可視化我放入混淆矩陣的數據。 有沒有一個函數我可以簡單地放置混淆矩陣,它可以將它可視化(繪制它)?

例如我想做的事情(Matrix$nnet 只是一個包含分類結果的表格):

Confusion$nnet <- confusionMatrix(Matrix$nnet)
plot(Confusion$nnet)

我的 Confusion$nnet$table 看起來像這樣:

    prediction (I would also like to get rid of this string, any help?)
    1  2
1   42 6
2   8 28

您可以僅使用 r 中的 rect 功能來布局混淆矩陣。 這里我們將創建一個函數,該函數允許用戶傳入由 caret 包創建的 cm 對象以產生視覺效果。

讓我們首先創建一個評估數據集,如插入符號演示中所做的那樣

# construct the evaluation dataset
set.seed(144)
true_class <- factor(sample(paste0("Class", 1:2), size = 1000, prob = c(.2, .8), replace = TRUE))
true_class <- sort(true_class)
class1_probs <- rbeta(sum(true_class == "Class1"), 4, 1)
class2_probs <- rbeta(sum(true_class == "Class2"), 1, 2.5)
test_set <- data.frame(obs = true_class,Class1 = c(class1_probs, class2_probs))
test_set$Class2 <- 1 - test_set$Class1
test_set$pred <- factor(ifelse(test_set$Class1 >= .5, "Class1", "Class2"))

現在讓我們使用插入符號來計算混淆矩陣

# calculate the confusion matrix
cm <- confusionMatrix(data = test_set$pred, reference = test_set$obs)

現在我們創建一個函數,根據需要布置矩形,以更具視覺吸引力的方式展示混淆矩陣

draw_confusion_matrix <- function(cm) {

  layout(matrix(c(1,1,2)))
  par(mar=c(2,2,2,2))
  plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
  title('CONFUSION MATRIX', cex.main=2)

  # create the matrix 
  rect(150, 430, 240, 370, col='#3F97D0')
  text(195, 435, 'Class1', cex=1.2)
  rect(250, 430, 340, 370, col='#F7AD50')
  text(295, 435, 'Class2', cex=1.2)
  text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
  text(245, 450, 'Actual', cex=1.3, font=2)
  rect(150, 305, 240, 365, col='#F7AD50')
  rect(250, 305, 340, 365, col='#3F97D0')
  text(140, 400, 'Class1', cex=1.2, srt=90)
  text(140, 335, 'Class2', cex=1.2, srt=90)

  # add in the cm results 
  res <- as.numeric(cm$table)
  text(195, 400, res[1], cex=1.6, font=2, col='white')
  text(195, 335, res[2], cex=1.6, font=2, col='white')
  text(295, 400, res[3], cex=1.6, font=2, col='white')
  text(295, 335, res[4], cex=1.6, font=2, col='white')

  # add in the specifics 
  plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
  text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
  text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
  text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
  text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
  text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
  text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
  text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
  text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
  text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
  text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

  # add in the accuracy information 
  text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
  text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
  text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
  text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
}  

最后,傳入我們在使用插入符號創建混淆矩陣時計算的cm 對象

draw_confusion_matrix(cm)

結果如下:

來自 caret 包的混淆矩陣的可視化

您可以使用內置的fourfoldplot 例如,

ctable <- as.table(matrix(c(42, 6, 8, 28), nrow = 2, byrow = TRUE))
fourfoldplot(ctable, color = c("#CC6666", "#99CC99"),
             conf.level = 0, margin = 1, main = "Confusion Matrix")

在此處輸入圖片說明

您可以使用函數conf_mat()yardstick加上autoplot()在幾行得到一個相當不錯的結果。

另外,您仍然可以使用基本的ggplot來修復樣式。

library(yardstick)
library(ggplot2)


# The confusion matrix from a single assessment set (i.e. fold)
cm <- conf_mat(truth_predicted, obs, pred)

autoplot(cm, type = "heatmap") +
  scale_fill_gradient(low="#D6EAF8",high = "#2E86C1")

在此處輸入圖片說明


作為進一步自定義的示例,使用ggplot您還可以添加回圖例:

+ theme(legend.position = "right")

更改圖例的名稱也很容易: + labs(fill="legend_name")

在此處輸入圖片說明

數據示例:

set.seed(123)
truth_predicted <- data.frame(
  obs = sample(0:1,100, replace = T),
  pred = sample(0:1,100, replace = T)
)
truth_predicted$obs <- as.factor(truth_predicted$obs)
truth_predicted$pred <- as.factor(truth_predicted$pred)

我真的很喜歡@Cyber​​netic 漂亮的混淆矩陣可視化,並做了兩個調整以希望進一步改進它。

1) 我用類的實際值交換了 Class1 和 Class2。 2)我用基於百分位數生成紅色(未命中)和綠色(命中)的函數替換橙色和藍色。 我們的想法是快速查看問題/成功的位置及其大小。

截圖和代碼:

混淆矩陣代碼更新

draw_confusion_matrix <- function(cm) {

  total <- sum(cm$table)
  res <- as.numeric(cm$table)

  # Generate color gradients. Palettes come from RColorBrewer.
  greenPalette <- c("#F7FCF5","#E5F5E0","#C7E9C0","#A1D99B","#74C476","#41AB5D","#238B45","#006D2C","#00441B")
  redPalette <- c("#FFF5F0","#FEE0D2","#FCBBA1","#FC9272","#FB6A4A","#EF3B2C","#CB181D","#A50F15","#67000D")
  getColor <- function (greenOrRed = "green", amount = 0) {
    if (amount == 0)
      return("#FFFFFF")
    palette <- greenPalette
    if (greenOrRed == "red")
      palette <- redPalette
    colorRampPalette(palette)(100)[10 + ceiling(90 * amount / total)]
  }

  # set the basic layout
  layout(matrix(c(1,1,2)))
  par(mar=c(2,2,2,2))
  plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
  title('CONFUSION MATRIX', cex.main=2)

  # create the matrix 
  classes = colnames(cm$table)
  rect(150, 430, 240, 370, col=getColor("green", res[1]))
  text(195, 435, classes[1], cex=1.2)
  rect(250, 430, 340, 370, col=getColor("red", res[3]))
  text(295, 435, classes[2], cex=1.2)
  text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
  text(245, 450, 'Actual', cex=1.3, font=2)
  rect(150, 305, 240, 365, col=getColor("red", res[2]))
  rect(250, 305, 340, 365, col=getColor("green", res[4]))
  text(140, 400, classes[1], cex=1.2, srt=90)
  text(140, 335, classes[2], cex=1.2, srt=90)

  # add in the cm results
  text(195, 400, res[1], cex=1.6, font=2, col='white')
  text(195, 335, res[2], cex=1.6, font=2, col='white')
  text(295, 400, res[3], cex=1.6, font=2, col='white')
  text(295, 335, res[4], cex=1.6, font=2, col='white')

  # add in the specifics 
  plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
  text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
  text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
  text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
  text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
  text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
  text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
  text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
  text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
  text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
  text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

  # add in the accuracy information 
  text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
  text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
  text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
  text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
}

這是一個簡單的基於ggplot2的想法,可以根據需要進行更改,我正在使用此鏈接中的數據:

#data
confusionMatrix(iris$Species, sample(iris$Species))
newPrior <- c(.05, .8, .15)
names(newPrior) <- levels(iris$Species)

cm <-confusionMatrix(iris$Species, sample(iris$Species))

現在 cm 是一個混淆矩陣對象,可以取出一些對問題有用的東西:

# extract the confusion matrix values as data.frame
cm_d <- as.data.frame(cm$table)
# confusion matrix statistics as data.frame
cm_st <-data.frame(cm$overall)
# round the values
cm_st$cm.overall <- round(cm_st$cm.overall,2)

# here we also have the rounded percentage values
cm_p <- as.data.frame(prop.table(cm$table))
cm_d$Perc <- round(cm_p$Freq*100,2)

現在我們准備好繪制:

library(ggplot2)     # to plot
library(gridExtra)   # to put more
library(grid)        # plot together

# plotting the matrix
cm_d_p <-  ggplot(data = cm_d, aes(x = Prediction , y =  Reference, fill = Freq))+
  geom_tile() +
  geom_text(aes(label = paste("",Freq,",",Perc,"%")), color = 'red', size = 8) +
  theme_light() +
  guides(fill=FALSE) 

# plotting the stats
cm_st_p <-  tableGrob(cm_st)

# all together
grid.arrange(cm_d_p, cm_st_p,nrow = 1, ncol = 2, 
             top=textGrob("Confusion Matrix and Statistics",gp=gpar(fontsize=25,font=1)))

在此處輸入圖片說明

@cybernetics:令人驚嘆的情節兄弟。 我通過這篇文章學到了很多東西。 可以使用此矩形和文本函數在模型摘要中表示結果時完成批次。 驚人的東西。

非常感謝。 祝你未來項目萬事如意。

我知道這已經很晚了,但我一直在尋找自己的解決方案。 除了這篇文章之外,還處理了上面的一些以前的答案。 使用ggplot2包和基table函數,我制作了這個簡單的函數來繪制一個漂亮的彩色混淆矩陣:

conf_matrix <- function(df.true, df.pred, title = "", true.lab ="True Class", pred.lab ="Predicted Class",
                        high.col = 'red', low.col = 'white') {
  #convert input vector to factors, and ensure they have the same levels
  df.true <- as.factor(df.true)
  df.pred <- factor(df.pred, levels = levels(df.true))
  
  #generate confusion matrix, and confusion matrix as a pecentage of each true class (to be used for color) 
  df.cm <- table(True = df.true, Pred = df.pred)
  df.cm.col <- df.cm / rowSums(df.cm)
  
  #convert confusion matrices to tables, and binding them together
  df.table <- reshape2::melt(df.cm)
  df.table.col <- reshape2::melt(df.cm.col)
  df.table <- left_join(df.table, df.table.col, by =c("True", "Pred"))
  
  #calculate accuracy and class accuracy
  acc.vector <- c(diag(df.cm)) / c(rowSums(df.cm))
  class.acc <- data.frame(Pred = "Class Acc.", True = names(acc.vector), value = acc.vector)
  acc <- sum(diag(df.cm)) / sum(df.cm)
  
  #plot
  ggplot() +
    geom_tile(aes(x=Pred, y=True, fill=value.y),
              data=df.table, size=0.2, color=grey(0.5)) +
    geom_tile(aes(x=Pred, y=True),
              data=df.table[df.table$True==df.table$Pred, ], size=1, color="black", fill = 'transparent') +
    scale_x_discrete(position = "top",  limits = c(levels(df.table$Pred), "Class Acc.")) +
    scale_y_discrete(limits = rev(unique(levels(df.table$Pred)))) +
    labs(x=pred.lab, y=true.lab, fill=NULL,
         title= paste0(title, "\nAccuracy ", round(100*acc, 1), "%")) +
    geom_text(aes(x=Pred, y=True, label=value.x),
              data=df.table, size=4, colour="black") +
    geom_text(data = class.acc, aes(Pred, True, label = paste0(round(100*value), "%"))) +
    scale_fill_gradient(low=low.col, high=high.col, labels = scales::percent,
                        limits = c(0,1), breaks = c(0,0.5,1)) +
    guides(size=F) +
    theme_bw() +
    theme(panel.border = element_blank(), legend.position = "bottom",
          axis.text = element_text(color='black'), axis.ticks = element_blank(),
          panel.grid = element_blank(), axis.text.x.top = element_text(angle = 30, vjust = 0, hjust = 0)) +
    coord_fixed()

} 

您只需復制並粘貼該函數,然后將其保存到您的全局環境中即可。

下面是一個例子:

mydata <- data.frame(true = c("a", "b", "c", "a", "b", "c", "a", "b", "c"),
                     predicted = c("a", "a", "c", "c", "a", "c", "a", "b", "c"))

conf_matrix(mydata$true, mydata$predicted, title = "Conf. Matrix Example")

在此處輸入圖片說明

cvmsplot_confusion_matrix()以及一些花里胡哨的:


# Create targets and predictions data frame
data <- data.frame(
  "target" = c("A", "B", "A", "B", "A", "B", "A", "B",
               "A", "B", "A", "B", "A", "B", "A", "A"),
  "prediction" = c("B", "B", "A", "A", "A", "B", "B", "B",
                   "B", "B", "A", "B", "A", "A", "A", "A"),
  stringsAsFactors = FALSE
)

# Evaluate predictions and create confusion matrix
eval <- evaluate(
  data = data,
  target_col = "target",
  prediction_cols = "prediction",
  type = "binomial"
)

eval

> # A tibble: 1 x 19
>   `Balanced Accuracy` Accuracy    F1 Sensitivity Specificity `Pos Pred Value` `Neg Pred Value`   AUC `Lower CI`
>                 <dbl>    <dbl> <dbl>       <dbl>       <dbl>            <dbl>            <dbl> <dbl>      <dbl>
> 1               0.690    0.688 0.667       0.714       0.667            0.625             0.75 0.690      0.447
> # … with 10 more variables: Upper CI <dbl>, Kappa <dbl>, MCC <dbl>, Detection Rate <dbl>,
> #   Detection Prevalence <dbl>, Prevalence <dbl>, Predictions <list>, ROC <named list>, Confusion Matrix <list>,
> #   Process <list>

# Plot confusion matrix
# Either supply confusion matrix tibble directly
plot_confusion_matrix(eval[["Confusion Matrix"]][[1]])

# Or plot first confusion matrix in evaluate() output
plot_confusion_matrix(eval)

混淆矩陣圖

輸出是一個 ggplot 對象。

最簡單的方法,結合插入符號:

library(caret)
library(yardstick)
library(ggplot2)

訓練模型

plsFit <- train(
  y ~ .,
  data = trainData
)

從模型中獲取預測

plsClasses <- predict(plsFit, newdata = testdata)

truth_predicted<-data.frame(
  obs = testdata$y,
  pred = plsClasses
)

制作矩陣。 注意 obs 和 pred 不是字符串

cm <- conf_mat(truth_predicted, obs, pred)

陰謀

autoplot(cm, type = "heatmap") +
  scale_fill_gradient(low="#D6EAF8",high = "#2E86C1")

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM