Average of confusion matrix in R
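I applied 10-fold cross-validation, and the output is 10 confusion matrices, one per fold. How can I find the average over the folds from these confusion matrices?
Is my approach correct?
Here is my code: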
set.seed(100)
library(caTools)
library(caret)
library(e1071)
folds<-createFolds(wpdc$outcome, k=10)
CV <- lapply(folds, function(x){
  traing_folds <- wpdc[-x, ]
  test_folds <- wpdc[x, ]
  dataset_model_nb <- naiveBayes(outcome ~ ., data = traing_folds)
  dataset_predict_nB <- predict(dataset_model_nb, test_folds[-1])
  dataset_table_nB <- table(test_folds[, 1], dataset_predict_nB)
  accuracy <- confusionMatrix(dataset_table_nB, positive = "R")
  return(accuracy)
})
A sample of the data (wpdc):
outcome radius_mean texture_mean perimeter_mean area_mean smoothness_mean compactness_mean concavity_mean concave_points_mean symmetry_mean fractal_dimension_mean radius_se texture_se perimeter_se area_se smoothness_se
1 N 18.02 27.60 117.50 1013.0 0.09489 0.1036 0.1086 0.07055 0.1865 0.06333 0.6249 1.8900 3.972 71.55 0.004433
2 N 17.99 10.38 122.80 1001.0 0.11840 0.2776 0.3001 0.14710 0.2419 0.07871 1.0950 0.9053 8.589 153.40 0.006399
3 N 21.37 17.44 137.50 1373.0 0.08836 0.1189 0.1255 0.08180 0.2333 0.06010 0.5854 0.6105 3.928 82.15 0.006167
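I needed the same thing, and following @Stephen Handerson's suggestion, I:
created a list to hold the per-fold confusion matrices:
rfConfusionMatrices <- list()
stored each fold's matrix in it inside the cross-validation loop (confMatrix has to be a plain table or matrix, e.g. the $table element of a confusionMatrix object, so that the matrices can be added):
rfConfusionMatrices[[i]] <- confMatrix
used the Reduce function to sum the matrices, then divided by the number of folds:
rfConfusionMatrixMean <- Reduce('+', rfConfusionMatrices) / nFolds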
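The steps above can be sketched end to end in base R. This is only a minimal illustration: the three 2x2 tables hold made-up counts (not taken from wpdc), and the names fold_tables, n_folds and mean_table are placeholders chosen for this example.

```r
# Toy per-fold confusion tables with made-up counts (rows = prediction,
# columns = reference). In practice these would come from each fold.
lv <- c("N", "R")
make_tab <- function(counts) {
  matrix(counts, nrow = 2, dimnames = list(Prediction = lv, Reference = lv))
}
fold_tables <- list(
  make_tab(c(12, 3, 2, 3)),
  make_tab(c(13, 2, 3, 2)),
  make_tab(c(11, 4, 1, 4))
)

n_folds <- length(fold_tables)

# Reduce applies `+` pairwise, giving the element-wise sum of all tables;
# dividing by the number of folds gives the element-wise mean.
mean_table <- Reduce(`+`, fold_tables) / n_folds
print(mean_table)
```

Because the sum is element-wise, all tables need the same dimensions and the same ordering of class levels.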
If you reorganize your code so that each fold returns its table of predictions and true labels:
set.seed(100)
library(caTools)
library(caret)
library(e1071)
folds <- createFolds(wpdc$outcome, k=10)
CV <- lapply(folds, function(x){
  traing_folds <- wpdc[-x, ]
  test_folds <- wpdc[x, ]
  dataset_model_nb <- naiveBayes(outcome ~ ., data = traing_folds)
  dataset_predict_nB <- predict(dataset_model_nb, test_folds[-1])
  # confusionMatrix() expects predictions in rows and true labels in columns
  dataset_table_nB <- table(dataset_predict_nB, test_folds[, 1])
  return(dataset_table_nB) # per-fold table of predicted vs. true labels
})
You can combine the per-fold tables by summing them element-wise with Reduce:
appended_table_nB <- Reduce('+', CV)
Then take the confusion matrix:
accuracy <- confusionMatrix(appended_table_nB, positive ="R")
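This is close to taking the average, with one difference: the counts in the confusion matrix are summed over the folds. Accuracy computed from the pooled table equals the fold-size-weighted mean of the per-fold accuracies (identical to the plain mean when all folds have the same size), and the other metrics are likewise computed from the pooled counts. If you want to see the averaged confusion matrix, divide by the number of folds:
averaged_matrix <- as.matrix(accuracy) / length(folds)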
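To make the relationship between pooling and averaging concrete, here is a small base-R check on made-up tables with unequal fold sizes. The counts and the helper acc are illustrative assumptions, not part of the original answer.

```r
# Made-up per-fold tables; the third fold is deliberately smaller.
fold_tables <- list(
  matrix(c(12, 3, 2, 3), nrow = 2),  # 20 cases
  matrix(c(13, 2, 3, 2), nrow = 2),  # 20 cases
  matrix(c(6, 1, 1, 2), nrow = 2)    # 10 cases
)

# Accuracy = correctly classified cases (diagonal) over all cases.
acc <- function(tab) sum(diag(tab)) / sum(tab)

fold_acc <- sapply(fold_tables, acc)   # accuracy of each fold
fold_n   <- sapply(fold_tables, sum)   # number of cases in each fold

pooled_acc   <- acc(Reduce(`+`, fold_tables))      # from summed counts
weighted_acc <- weighted.mean(fold_acc, fold_n)    # size-weighted mean
simple_acc   <- mean(fold_acc)                     # unweighted mean

print(c(pooled = pooled_acc, weighted = weighted_acc, simple = simple_acc))
```

The pooled and weighted values agree, while the unweighted mean differs slightly because the folds are not the same size. With createFolds the fold sizes are nearly equal, so the gap is usually small.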
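I was just googling to find out whether computing the mean from confusion matrices is common. In case anyone is interested in a solution that can be adapted to store more than just the mean:
I defined the following function to get the mean and standard deviation from a list of confusion matrices or similar objects, provided that all of the matrices have the same format: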
average_matr <- function(matr_list){
  # accept either confusionMatrix objects or plain matrices/tables
  if(inherits(matr_list[[1]], "confusionMatrix")){
    matr_lst <- lapply(matr_list, FUN = function(x){x$table})
  }else{
    matr_lst <- matr_list
  }
  # one column per matrix in the list, one row per matrix cell
  vals <- lapply(matr_lst, as.numeric)
  matr <- do.call(cbind, vals)
  #vec_mean <- apply(matr, MARGIN = 1, FUN = mean, na.rm = TRUE)
  vec_mean <- rowMeans(matr, na.rm = TRUE)
  matr_mean <- matrix(vec_mean, nrow = nrow(matr_lst[[1]]))
  vec_sd <- apply(matr, MARGIN = 1, FUN = sd, na.rm = TRUE)
  matr_sd <- matrix(vec_sd, nrow = nrow(matr_lst[[1]]))
  # first element: cell-wise means, second element: cell-wise standard deviations
  out <- list(matr_mean, matr_sd)
  return(out)
}
average_matr(confusion_matr)
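If the objects in the list are of class confusionMatrix, the function first extracts just their tables. If it is already a list of matrices, it computes the mean and standard deviation directly.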
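Note that, as far as I know, rowMeans is faster than apply with FUN = mean, but to my knowledge there is no equivalent built-in function for sd. I left the apply version with mean commented out because its syntax matches the sd line; it could be used instead of rowMeans, and for smaller data sets there should be no noticeable difference.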
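As a quick sanity check of that note, the following sketch compares rowMeans with apply on made-up matrices. The list tabs and its counts are assumptions for illustration only.

```r
# Three made-up 2x2 matrices standing in for per-fold confusion tables.
tabs <- list(
  matrix(c(12, 3, 2, 3), nrow = 2),
  matrix(c(13, 2, 3, 2), nrow = 2),
  matrix(c(11, 4, 1, 4), nrow = 2)
)

# Flatten each matrix to a vector and bind them as columns:
# one row per matrix cell, one column per fold.
cells <- do.call(cbind, lapply(tabs, as.numeric))

mean_fast  <- rowMeans(cells)                       # vectorised mean
mean_apply <- apply(cells, MARGIN = 1, FUN = mean)  # same result via apply
sd_apply   <- apply(cells, MARGIN = 1, FUN = sd)    # per-cell standard deviation

# Reshape the per-cell vectors back to the original matrix layout.
mean_matrix <- matrix(mean_fast, nrow = nrow(tabs[[1]]))
sd_matrix   <- matrix(sd_apply, nrow = nrow(tabs[[1]]))
print(mean_matrix)
print(sd_matrix)
```

Both ways of computing the mean give identical values here, so the choice between them is only a matter of speed and style.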
Edit: added both versions.
Addendum: a version that also exports to a LaTeX table
average_matr <- function(matr_list, latex_file = NA,
                         metric = "sd", return = TRUE){
  # accept either confusionMatrix objects or plain matrices/tables
  if(inherits(matr_list[[1]], "confusionMatrix")){
    matr_lst <- lapply(matr_list, FUN = function(x){x$table})
  }else{
    matr_lst <- matr_list
  }
  # one column per matrix in the list, one row per matrix cell
  vals <- lapply(matr_lst, as.numeric)
  matr <- do.call(cbind, vals)
  #vec_mean <- apply(matr, MARGIN = 1, FUN = mean, na.rm = TRUE)
  vec_mean <- rowMeans(matr, na.rm = TRUE)
  matr_mean <- matrix(vec_mean, nrow = nrow(matr_lst[[1]]))
  # spread per cell: standard deviation, standard error, or none
  if(metric == "sd"){
    vec_sd <- apply(matr, MARGIN = 1, FUN = sd, na.rm = TRUE)
  }else if(metric == "se"){
    vec_sd <- apply(matr, MARGIN = 1,
                    FUN = function(x){sd(x, na.rm = TRUE)/sqrt(length(x))})
  }else{
    vec_sd <- NA
  }
  if(length(vec_sd) > 1){
    matr_sd <- matrix(vec_sd, nrow = nrow(matr_lst[[1]]))
    out <- list(matr_mean, matr_sd)
  }else{
    out <- matr_mean
  }
  # generate latex table
  if(is.character(latex_file)){
    if(dir.exists(dirname(latex_file))){
      sink(latex_file)
      cat("\\hline\n")
      cat(paste(row.names(matr_lst[[1]]), collapse = " & "), "\\\\\n")
      cat("\\hline\n")
      if(length(vec_sd) > 1){
        # one table row per matrix row, formatted as mean +/- spread
        for(r in seq_len(nrow(matr_mean))){
          cat(paste(formatC(matr_mean[r, ], digits = 1, format = "f"),
                    formatC(matr_sd[r, ], digits = 1, format = "f"),
                    sep = " \\(\\pm\\) ", collapse = " & "), "\\\\\n")
        }
      }else{
        # means only; index row r so each table row prints its own values
        for(r in seq_len(nrow(matr_mean))){
          cat(paste(formatC(matr_mean[r, ], digits = 1, format = "f"),
                    collapse = " & "), "\\\\\n")
        }
      }
      cat("\\hline\n")
      sink()
    }else{
      warning("Directory not found: ", dirname(latex_file))
    }
  }
  if(return){
    return(out)
  }
}