簡體   English   中英

更改由 Caret 在 R 創建的 plot 中顯示的調整參數

[英]Change tuning parameters shown in the plot created by Caret in R

I'm using the Caret package in R to train a model by the method called 'xgbTree' in R.

繪制經過訓練的 model 后,如下圖所示:調整參數即 'eta' = 0.2 不是我想要的,因為在訓練 model 之前,我還在 expand.grid 中定義了 eta = 0.1 作為調整參數,這是最好的調整. 所以我想將 plot 中的 eta = 0.2 更改為 plot function 中的 eta = 0.1 的情況。 我怎么能做到? 謝謝你。

在此處輸入圖像描述

set.seed(100)  # For reproducibility

xgb_trcontrol = trainControl(
method = "cv",
#repeats = 2,
number = 10,  
#search = 'random',
allowParallel = TRUE,
verboseIter = FALSE,
returnData = TRUE
)


xgbGrid <- expand.grid(nrounds = c(100,200,1000),  # this is n_estimators in the python code above
                   max_depth = c(6:8),
                   colsample_bytree = c(0.6,0.7),
                   ## The values below are default values in the sklearn-api. 
                   eta = c(0.1,0.2),
                   gamma=0,
                   min_child_weight = c(5:8),
                   subsample = c(0.6,0.7,0.8,0.9)
)


set.seed(0) 
xgb_model8 = train(
x, y_train,  
trControl = xgb_trcontrol,
tuneGrid = xgbGrid,
method = "xgbTree"
)

發生的情況是繪圖設備繪制了網格的所有值,最后出現的是 eta=0.2。 例如:

xgb_trcontrol = trainControl(method = "cv", number = 3,returnData = TRUE)

xgbGrid <- expand.grid(nrounds = c(100,200,1000),  
                   max_depth = c(6:8),
                   colsample_bytree = c(0.6,0.7), 
                   eta = c(0.1,0.2),
                   gamma=0,
                   min_child_weight = c(5:8),
                   subsample = c(0.6,0.7,0.8,0.9)
)

set.seed(0)

x = mtcars[,-1]
y_train = mtcars[,1]

xgb_model8 = train(
x, y_train,  
trControl = xgb_trcontrol,
tuneGrid = xgbGrid,
method = "xgbTree"
)

你可以像這樣保存你的地塊:

pdf("plots.pdf")
plot(xgb_model8,metric="RMSE")
dev.off()

或者如果你想 plot 一個特定的參數,例如 eta = 0.2,你還需要修復colsample_bytree ,否則參數太多:

library(ggplot2)

ggplot(subset(xgb_model8$results
,eta==0.1 & colsample_bytree==0.6),
aes(x=min_child_weight,y=RMSE,group=factor(subsample),col=factor(subsample))) + 
geom_line() + geom_point() + facet_grid(nrounds~max_depth)

在此處輸入圖像描述

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM