How can I set the ntree parameter in the train() function of the caret package?
I am using the following function to do cross-validation with the random forest algorithm on my dataset. However, ntree raises an error saying that it is not used in the function. I have seen this usage recommended in a comment in one of the threads about this issue, but it did not work for me. Here is my code:
cv_rf_class1 <- train(y_train_u ~ ., x_train_u,
                      method = "cforest",
                      trControl = trainControl(method = "cv",
                                               number = 10,
                                               verboseIter = TRUE),
                      ntree = 100))
If I cannot change the ntree parameter, the function uses its default of 500 trees, which raises another error for me (subscript out of bounds), so I cannot make it work for my problem. How can I fix this so that my function works?
ntree needs to be an argument of train, and not of trainControl as you have used it here; from the documentation of train:
...
arguments passed to the classification or regression routine (such as randomForest). Errors will occur if values for tuning parameters are passed here.
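To illustrate the pass-through described in that quote, here is a minimal sketch with method = "rf", whose underlying randomForest() function does take an ntree argument. It assumes the caret and randomForest packages are installed and uses iris purely as stand-in data; mtry is supplied through tuneGrid because it is the tuning parameter for this method and, per the quote, must not go through `...`.

```r
# Sketch: with method = "rf", extra arguments such as ntree are forwarded
# through train()'s `...` to randomForest::randomForest().
# Assumes caret and randomForest are installed; iris is stand-in data.
library(caret)

set.seed(1)
x <- iris[, 1:4]   # predictors
y <- iris$Species  # factor outcome

fit_rf <- train(x, y,
                method = "rf",
                ntree = 100,                      # forwarded via `...`
                tuneGrid = data.frame(mtry = 2),  # tuning parameter goes here
                trControl = trainControl(method = "cv", number = 3))

# The final model is a randomForest object that records its tree count
fit_rf$finalModel$ntree
```

Whether an argument in `...` works therefore depends on whether the underlying modelling function accepts it.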
Notice also that you are not passing the data in the correct form; train expects the data as (x, y), and not as you are passing them (an incorrect combination of formula and matrices).
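A minimal sketch of the two input forms train() accepts, using iris as a stand-in for x_train_u and y_train_u (which are not shown in the question). It assumes caret and party are installed; the small ntree and the fixed mtry grid are illustrative choices only. In the formula form, the outcome has to be a column of the data frame passed as data, rather than a separate vector.

```r
# Sketch: the two input forms that caret::train() accepts.
# iris stands in for the question's x_train_u / y_train_u (not shown);
# assumes the caret and party packages are installed.
library(caret)

ctrl <- trainControl(method = "cv", number = 3)
grid <- data.frame(mtry = 2)                    # mtry is cforest's tuning parameter
cf_ctrl <- party::cforest_unbiased(ntree = 20)  # small forest, for speed only

# 1) x/y form: predictors and outcome passed as separate objects
set.seed(1)
fit_xy <- train(iris[, 1:4], iris$Species,
                method = "cforest", tuneGrid = grid,
                trControl = ctrl, controls = cf_ctrl)

# 2) formula form: the outcome must be a column of the data frame in `data`
set.seed(1)
fit_formula <- train(Species ~ ., data = iris,
                     method = "cforest", tuneGrid = grid,
                     trControl = ctrl, controls = cf_ctrl)
```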
All in all, change your train call to:
cv_rf_class1 <- train(x_train_u, y_train_u,
                      method = "cforest",
                      ntree = 100,
                      trControl = trainControl(method = "cv",
                                               number = 10,
                                               verboseIter = TRUE))
UPDATE (after comments)
Well, it seems that cforest in particular will not accept an ntree argument because, in contrast with the original randomForest package, this is not how you pass the number of trees to the underlying cforest function of the respective package (see the cforest documentation of the party package).
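A minimal sketch of how this looks when calling party::cforest() directly, assuming the party package is installed and using iris as stand-in data. The tree count goes into the control object built by cforest_unbiased(), while a top-level ntree argument is rejected because cforest() has no such argument of its own.

```r
# Sketch: party::cforest() takes its tree count from the `controls` object.
# Assumes the party package is installed; iris is stand-in data.
library(party)

set.seed(1)
# ntree (and mtry) are set inside cforest_unbiased(), which builds the controls
cf <- cforest(Species ~ ., data = iris,
              controls = cforest_unbiased(ntree = 50, mtry = 2))

# A top-level ntree argument is rejected, since cforest() does not define one
bad_call <- tryCatch(cforest(Species ~ ., data = iris, ntree = 50),
                     error = function(e) e)
inherits(bad_call, "error")

# Predictions from the forest grown with the controls object
pred <- predict(cf, newdata = iris)
table(pred, iris$Species)
```

When going through caret, the same controls object is simply passed along to cforest() via train()'s `...`.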
The correct way, as demonstrated in the relevant examples in the caret GitHub repo, is:
cv_rf_class1 <- train(x_train_u, y_train_u,
                      method = "cforest",
                      trControl = trainControl(method = "cv",
                                               number = 10,
                                               verboseIter = TRUE),
                      controls = party::cforest_unbiased(ntree = 100))
Adapting cforest.R, we get:
library(caret)
library(plyr)
library(recipes)
library(dplyr)
model <- "cforest"
set.seed(2)
training <- twoClassSim(50, linearVars = 2)
testing <- twoClassSim(500, linearVars = 2)
trainX <- training[, -ncol(training)]
trainY <- training$Class
# Recipe carried over from cforest.R; it is not used by the train() call below
rec_cls <- recipe(Class ~ ., data = training) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())
# List of seed vectors for reproducible resampling
seeds <- vector(mode = "list", length = nrow(training) + 1)
seeds <- lapply(seeds, function(x) 1:20)
cctrl1 <- trainControl(method = "cv", number = 3, returnResamp = "all",
                       classProbs = TRUE,
                       summaryFunction = twoClassSummary,
                       seeds = seeds)
set.seed(849)
# The number of trees is set through party's control object, passed as `controls`
test_class_cv_model <- train(trainX, trainY,
                             method = "cforest",
                             trControl = cctrl1,
                             metric = "ROC",
                             preProc = c("center", "scale"),
                             controls = party::cforest_unbiased(ntree = 20)) # WORKS OK
test_class_pred <- predict(test_class_cv_model, testing[, -ncol(testing)])
test_class_prob <- predict(test_class_cv_model, testing[, -ncol(testing)], type = "prob")
head(test_class_pred)
# [1] Class2 Class2 Class2 Class1 Class1 Class1
# Levels: Class1 Class2
head(test_class_prob)
# Class1 Class2
# 1 0.4996686 0.5003314
# 2 0.4333222 0.5666778
# 3 0.3625118 0.6374882
# 4 0.5373396 0.4626604
# 5 0.6174159 0.3825841
# 6 0.5327283 0.4672717
Output of sessionInfo():
R version 3.6.1 (2019-07-05)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 7 x64 (build 7601) Service Pack 1
Matrix products: default
locale:
[1] LC_COLLATE=English_United Kingdom.1252 LC_CTYPE=English_United Kingdom.1252 LC_MONETARY=English_United Kingdom.1252
[4] LC_NUMERIC=C LC_TIME=English_United Kingdom.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] recipes_0.1.7 dplyr_0.8.3 plyr_1.8.4 caret_6.0-84 ggplot2_3.2.1 lattice_0.20-38