[英]average of confusion matrix in R
我应用了 10 次交叉验证,输出是混淆矩阵的 10 倍,那么如何通过混淆矩阵找到折叠的平均值?
我的工作是正确的吗?
这是我的代码:
set.seed(100)
library(caTools)
library(caret)
library(e1071)
folds<-createFolds(wpdc$outcome, k=10)
CV <- lapply(folds, function(x){
traing_folds=wpdc[-x,]
test_folds=wpdc[x,]
dataset_model_nb<-naiveBayes(outcome ~ ., data = traing_folds)
dataset_predict_nB<-predict(dataset_model_nb, test_folds[-1])
dataset_table_nB<-table(test_folds[,1],dataset_predict_nB)
accuracy<-confusionMatrix(dataset_table_nB, positive ="R")
return(accuracy)
})
outcome radius_mean texture_mean perimeter_mean area_mean smoothness_mean compactness_mean concavity_mean concave_points_mean symmetry_mean fractal_dimension_mean radius_se texture_se perimeter_se area_se smoothness_se
1 N 18.02 27.60 117.50 1013.0 0.09489 0.1036 0.1086 0.07055 0.1865 0.06333 0.6249 1.8900 3.972 71.55 0.004433
2 N 17.99 10.38 122.80 1001.0 0.11840 0.2776 0.3001 0.14710 0.2419 0.07871 1.0950 0.9053 8.589 153.40 0.006399
3 N 21.37 17.44 137.50 1373.0 0.08836 0.1189 0.1255 0.08180 0.2333 0.06010 0.5854 0.6105 3.928 82.15 0.006167
我需要同样的东西,然后按照@Stephen Handerson 的提示,我是:
rfConfusionMatrices <- list()
RrfConfusionMatrix[[i]] <- confMatrix
Reduce
函数对矩阵求和并除以折叠:
rfConfusionMatrixMean <- Reduce('+', rfConfusionMatrix) / nFolds
如果您重新组织代码并将预测和真实标签存储为:
set.seed(100)
library(caTools)
library(caret)
library(e1071)
folds <- createFolds(wpdc$outcome, k=10)
CV <- lapply(folds, function(x){
traing_folds=wpdc[-x,]
test_folds=wpdc[x,]
dataset_model_nb<-naiveBayes(outcome ~ ., data = traing_folds)
dataset_predict_nB<-predict(dataset_model_nb, test_folds[-1])
dataset_table_nB<-table(test_folds[,1],dataset_predict_nB)
return(dataset_table_nB) # storing true and predicted values
})
您可以通过减少来附加它们:
appended_table_nB<- do.call(rbind, dataset_table_nB)
然后取混淆矩阵:
accuracy <- confusionMatrix(appended_table_nB, positive ="R")
这与取平均值相同。 唯一的区别是您对 conf 矩阵中的数据点求和,但准确度和其他指标是它们的平均值。 如果你想看到 conf 矩阵的平均值,你可以:
averaged_matrix <- as.matrix(accuracy) / nFold
我只是在谷歌上搜索以了解从混淆矩阵计算均值是否很常见。 以防万一有人对可以调整以节省的不仅仅是平均值的解决方案感兴趣:
我定义了以下函数来从混淆矩阵或类似对象list
中获取均值和标准差,前提是所有这些矩阵都具有相同的格式:
average_matr <- function(matr_list){
if(class(matr_list[[1]])[1] == "confusionMatrix"){
matr_lst <- lapply(matr_list, FUN = function(x){x$table})
}else{
matr_lst <- matr_list
}
vals <- lapply(matr_lst, as.numeric)
matr <- do.call(cbind, vals)
#vec_mean <- apply(matr, MARGIN = 1, FUN = mean, na.rm = TRUE)
vec_mean <- rowMeans(matr, na.rm = TRUE)
matr_mean <- matrix(vec_mean, nrow = nrow(matr_lst[[1]]))
vec_sd <- apply(matr, MARGIN = 1, FUN = sd, na.rm = TRUE)
matr_sd <- matrix(vec_sd, nrow = nrow(matr_lst[[1]]))
out <- list(matr_mean, matr_sd)
return(out)
}
average_matr(confusion_matr)
如果列表中的对象属于confusionMatrix
类,则该函数将只提取值。 如果它是一个矩阵列表,它将计算均值和标准差。
请注意,据我所知, rowMeans
比apply
FUN = mean
应用更快,但据我所知没有sd
函数。 虽然我使用了类似的语法,但apply
with mean
可以被替换,但对于较小的数据集,应该没有明显的区别。
编辑:添加了两个版本。
附加:包括导出为 LaTeX 表
average_matr <- function(matr_list, latex_file = NA,
metric = "sd", return = TRUE){
if(class(matr_list[[1]])[1] == "confusionMatrix"){
matr_lst <- lapply(matr_list, FUN = function(x){x$table})
}else{
matr_lst <- matr_list
}
vals <- lapply(matr_lst, as.numeric)
matr <- do.call(cbind, vals)
#vec_mean <- apply(matr, MARGIN = 1, FUN = mean, na.rm = TRUE)
vec_mean <- rowMeans(matr, na.rm = TRUE)
matr_mean <- matrix(vec_mean, nrow = nrow(matr_lst[[1]]))
if(metric == "sd"){
vec_sd <- apply(matr, MARGIN = 1, FUN = sd, na.rm = TRUE)
}else if(metric == "se"){
vec_sd <- apply(matr, MARGIN = 1,
FUN = function(x){sd(x, na.rm = TRUE)/sqrt(length(x))})
}else{
vec_sd <- NA
}
if(length(vec_sd) > 1){
matr_sd <- matrix(vec_sd, nrow = nrow(matr_lst[[1]]))
out <- list(matr_mean, matr_sd)
}else{
out <- matr_mean
}
# generate latex table
if(is.character(latex_file)){
if(dir.exists(dirname(latex_file))){
sink(latex_file)
cat("\\hline\n")
cat(paste(row.names(matr_lst[[1]]), collapse = " & "), "\\\\\n")
cat("\\hline\n")
if(length(vec_sd) > 1){
for(r in 1:nrow(matr_mean)){
cat(paste(formatC(matr_mean[r, ], digits = 1, format = "f"),
formatC(matr_sd[r, ], digits = 1, format = "f"),
sep = " \\(\\pm\\) ", collapse = " & "), "\\\\\n")
}
}else{
for(r in 1:nrow(matr_mean)){
cat(paste(formatC(matr_mean, digits = 1, format = "f"),
collapse = " & "), "\\\\\n")
}
}
cat("\\hline\n")
sink()
}else{
warning("Directory not found: ", latex_file)
}
}
if(return){
return(out)
}
}
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.