
How to plot ROC curves for every cross-validation fold using Caret

I have the following code:

library(mlbench)
library(caret)
library(ggplot2)
set.seed(998)

# Prepare data ------------------------------------------------------------

data(Sonar)
my_data <- Sonar

# Cross Validation Definition ---------------------------------------------------

fitControl <-
  trainControl(
    method = "cv",
    number = 10,
    classProbs = T,
    savePredictions = T,
    summaryFunction = twoClassSummary
  )


# Training with Random Forest ----------------------------------------------------------------


model <- train(
  Class ~ .,
  data = my_data,
  method = "rf",
  trControl = fitControl,
  metric = "ROC"
)

for_lift <- data.frame(Class = model$pred$obs, rf = model$pred$R)
lift_obj <- lift(Class ~ rf, data = for_lift, class = "R")


# Plot ROC ----------------------------------------------------------------

ggplot(lift_obj$data) +
  geom_line(aes(1 - Sp, Sn, color = liftModelVar)) +
  scale_color_discrete(guide = guide_legend(title = "method"))

This produces a plot with a single ROC curve.

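Note that I am running 10-fold cross-validation, but the plot shows only one overall ROC curve, built from the hold-out predictions of all folds together.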

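What I want is 10 ROC curves, one for each cross-validation fold. How can I do that?

One way is to split the saved hold-out predictions (`model$pred`) by fold (`Resample`) and call `lift()` on each fold separately:
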
library(mlbench)
library(caret)
library(ggplot2)
set.seed(998)

# Prepare data ------------------------------------------------------------

data(Sonar)
my_data <- Sonar

# Cross Validation Definition ---------------------------------------------------

fitControl <-
  trainControl(
    method = "cv",
    number = 10,
    classProbs = T,
    savePredictions = T,
    summaryFunction = twoClassSummary
  )


# Training with Random Forest ----------------------------------------------------------------


model <- train(
  Class ~ .,
  data = my_data,
  method = "rf",
  trControl = fitControl,
  metric = "ROC"
)


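# Keep only the hold-out predictions for the selected mtry; with
# savePredictions = T, model$pred also holds rows for every other candidate
best_pred <- model$pred[model$pred$mtry == model$bestTune$mtry, ]
for_lift <- data.frame(Class = best_pred$obs, rf = best_pred$R, resample = best_pred$Resample)
lift_df <- data.frame()
for (fold in unique(for_lift$resample)) {
  # Hold-out predictions of this fold only
  fold_df <- dplyr::filter(for_lift, resample == fold)
  # Sensitivity/specificity pairs for this fold
  lift_obj_data <- lift(Class ~ rf, data = fold_df, class = "R")$data
  lift_obj_data$fold <- fold
  lift_df <- rbind(lift_df, lift_obj_data)
}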


# Plot ROC ----------------------------------------------------------------

ggplot(lift_df) +
  geom_line(aes(1 - Sp, Sn, color = fold)) +
  scale_color_discrete(guide = guide_legend(title = "Fold"))

Resulting plot: one ROC curve per fold.
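If you would rather not depend on `lift()`, the per-fold curves can also be sketched in base R. This is a minimal sketch, assuming a data frame shaped like `model$pred` (columns `obs`, `R` and `Resample`, with `"R"` as the event class); the helper names `roc_points()` and `per_fold_roc()` and the toy data are hypothetical. It uses `split()`/`lapply()` instead of growing a data frame with `rbind()` inside a loop.

```r
# Hypothetical helper: ROC points (FPR, TPR) at every distinct score threshold
roc_points <- function(score, is_event) {
  thresholds <- sort(unique(score), decreasing = TRUE)
  tpr <- vapply(thresholds, function(thr) mean(score[is_event] >= thr), numeric(1))
  fpr <- vapply(thresholds, function(thr) mean(score[!is_event] >= thr), numeric(1))
  # Prepend the (0, 0) corner so every curve starts at the origin
  data.frame(fpr = c(0, fpr), tpr = c(0, tpr))
}

# Hypothetical helper: one ROC curve per fold, stacked into a long data frame
per_fold_roc <- function(pred, event = "R") {
  folds <- split(pred, pred$Resample)
  pieces <- lapply(names(folds), function(f) {
    d <- folds[[f]]
    out <- roc_points(d[[event]], d$obs == event)
    out$fold <- f
    out
  })
  do.call(rbind, pieces)
}

# Made-up toy data mimicking the columns of model$pred
toy_pred <- data.frame(
  obs = factor(c("R", "M", "R", "M", "R", "M", "M", "R"), levels = c("M", "R")),
  R = c(0.9, 0.2, 0.7, 0.6, 0.4, 0.1, 0.8, 0.3),
  Resample = rep(c("Fold01", "Fold02"), each = 4)
)
roc_df <- per_fold_roc(toy_pred)
```

On the real model you would pass only the rows for the selected `mtry`, e.g. `per_fold_roc(model$pred[model$pred$mtry == model$bestTune$mtry, ])`, and then plot with something like `ggplot(roc_df, aes(fpr, tpr, colour = fold)) + geom_path()`.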

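Per-fold accuracy (because `savePredictions = T` keeps predictions for every candidate `mtry`, each fold's value below pools all of them):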

model <- train(
  Class ~ .,
  data = my_data,
  method = "rf",
  trControl = fitControl,
  metric = "ROC"
)

library(plyr)
library(MLmetrics)
# Accuracy of the hold-out class predictions within each fold
ddply(model$pred, "Resample", summarise,
      accuracy = Accuracy(pred, obs))

Output:

   Resample  accuracy
1    Fold01 0.8253968
2    Fold02 0.8095238
3    Fold03 0.8000000
4    Fold04 0.8253968
5    Fold05 0.8095238
6    Fold06 0.8253968
7    Fold07 0.8333333
8    Fold08 0.8253968
9    Fold09 0.9841270
10   Fold10 0.7936508
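The figures above are accuracies, not AUCs. A per-fold AUC can be sketched in base R with the rank-sum (Mann-Whitney) formulation, AUC = (sum of event ranks - n_pos(n_pos + 1)/2) / (n_pos * n_neg), where `rank()` gives tied scores their average rank. As with the other sketch, this assumes columns `obs`, `R` and `Resample` and uses made-up toy data; `auc_rank()` and `per_fold_auc()` are hypothetical helpers, not caret functions.

```r
# Hypothetical helper: rank-based AUC, i.e. the probability that a random
# event scores higher than a random non-event (ties count as one half)
auc_rank <- function(score, is_event) {
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  r <- rank(score)  # ties get their average rank
  (sum(r[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical helper: one AUC per fold, named by Resample
per_fold_auc <- function(pred, event = "R") {
  sapply(split(pred, pred$Resample),
         function(d) auc_rank(d[[event]], d$obs == event))
}

# Made-up toy data mimicking the columns of model$pred
toy_pred <- data.frame(
  obs = factor(c("R", "M", "R", "M", "R", "M", "M", "R"), levels = c("M", "R")),
  R = c(0.9, 0.2, 0.7, 0.6, 0.4, 0.1, 0.8, 0.3),
  Resample = rep(c("Fold01", "Fold02"), each = 4)
)
fold_auc <- per_fold_auc(toy_pred)  # Fold01 = 1, Fold02 = 0.5
print(fold_auc)
```

For the real model, filter to the selected tuning value first, e.g. `per_fold_auc(model$pred[model$pred$mtry == model$bestTune$mtry, ])`, so that each fold's AUC is not mixed across `mtry` candidates.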
