簡體   English   中英

使用帶有gbm方法的插入符號進行多類分類

[英]Usage of caret with gbm method for multiclass classification

我正在解決多類分類問題,並嘗試使用廣義Boosted模型(R中的gbm包)。 我遇到的問題:使用method="gbm"插入符號train函數似乎無法正確處理多類數據。 下面給出一個簡單的例子。

library(gbm)
library(caret)
data(iris)
fitControl <- trainControl(method="repeatedcv",
                           number=5,
                           repeats=1,
                           verboseIter=TRUE)
set.seed(825)
gbmFit <- train(Species ~ ., data=iris,
                method="gbm",
                trControl=fitControl,
                verbose=FALSE)
gbmFit

輸出是

+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150 
...
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
Aggregating results
Selecting tuning parameters
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set
Error in if (interaction.depth < 1) { : argument is of length zero

然而,如果我嘗試使用沒有插入包裝的gbm,我會得到很好的結果。

set.seed(1365)
train <- createDataPartition(iris$Species, p=0.7, list=F)
train.iris <- iris[train,]
valid.iris <- iris[-train,]
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, n.trees=200, verbose=FALSE)
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response")
gbm.pred <- as.factor(colnames(gbm.pred)[max.col(gbm.pred)]) ##!
confusionMatrix(gbm.pred, valid.iris$Species)$overall

僅供參考,代碼在線標有##! predict.gbm返回的類概率矩陣轉換為最可能類的因子。 輸出是

      Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull AccuracyPValue  McnemarPValue 
  9.111111e-01   8.666667e-01   7.877883e-01   9.752470e-01   3.333333e-01   8.467252e-16            NaN 

有關如何使gtm在多類數據上正常工作的任何建議嗎?

UPD:

sessionInfo()
R version 2.15.3 (2013-03-01)
Platform: x86_64-pc-linux-gnu (64-bit)

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=C                 LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] e1071_1.6-1      class_7.3-5      gbm_2.0-8        survival_2.36-14 caret_5.15-61    reshape2_1.2.2   plyr_1.8        
 [8] lattice_0.20-13  foreach_1.4.0    cluster_1.14.3   compare_0.2-3   

loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3     iterators_1.0.6 stringr_0.6.2   tools_2.15.3   

這是我正在研究的一個問題。

如果你發布了sessionInfo()的結果會有所幫助。

此外,從https://code.google.com/p/gradientboostedmodels/獲取最新的gbm可能會解決問題。

馬克斯

更新: Caret可以進行多級分類。

您應確保類標簽采用字母數字格式(以字母開頭)。

例如:如果您的數據標簽為“1”,“2”,“3”,則將其更改為“Seg1”,“Seg2”和“Seg3”,否則將失敗。

更新:原始代碼確實運行並產生以下輸出

+ Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150 
- Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150 
...
...
...
+ Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150 
- Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150 
Aggregating results
Selecting tuning parameters
Fitting n.trees = 50, interaction.depth = 2, shrinkage = 0.1 on full training set
> gbmFit
Stochastic Gradient Boosting 

150 samples
  4 predictor
  3 classes: 'setosa', 'versicolor', 'virginica' 

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 1 times) 

Summary of sample sizes: 120, 120, 120, 120, 120 

Resampling results across tuning parameters:

  interaction.depth  n.trees  Accuracy   Kappa  Accuracy SD
  1                   50      0.9400000  0.91   0.04346135 
  1                  100      0.9400000  0.91   0.03651484 
  1                  150      0.9333333  0.90   0.03333333 
  2                   50      0.9533333  0.93   0.04472136 
  2                  100      0.9533333  0.93   0.05055250 
  2                  150      0.9466667  0.92   0.04472136 
  3                   50      0.9333333  0.90   0.03333333 
  3                  100      0.9466667  0.92   0.04472136 
  3                  150      0.9400000  0.91   0.03651484 
  Kappa SD  
  0.06519202
  0.05477226
  0.05000000
  0.06708204
  0.07582875
  0.06708204
  0.05000000
  0.06708204
  0.05477226

Tuning parameter 'shrinkage' was held constant at a value of 0.1
Accuracy was used to select the optimal model using  the
 largest value.
The final values used for the model were n.trees =
 50, interaction.depth = 2 and shrinkage = 0.1. 
> summary(gbmFit)
                      var    rel.inf
Petal.Length Petal.Length 74.1266408
Petal.Width   Petal.Width 22.0668983
Sepal.Width   Sepal.Width  3.2209288
Sepal.Length Sepal.Length  0.5855321

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM