簡體   English   中英

在R中使用randomForest包,如何從分類模型中獲取概率?

[英]Using randomForest package in R, how to get probabilities from classification model?

TL; DR:

我可以在原始 randomForest 調用中 標記一些內容, 以避免重新運行 predict 函數以獲得預測的分類概率,而不僅僅是可能的類別嗎?

細節:

我正在使用randomForest包。

我有一個類似的模型:

model <- randomForest(x=out.data[train.rows, feature.cols],
                      y=out.data[train.rows, response.col],
                      xtest=out.data[test.rows, feature.cols],
                      ytest=out.data[test.rows, response.col],
                      importance= TRUE)

其中out.data是一個數據框,其中feature.cols是數字和分類特征的混合,而response.col是一個TRUE / FALSE二進制變量,我強制插入factor以便randomForest模型將其正確地視為分類。

一切運行良好,變量model正確返回給我。 但是,我似乎無法找到傳遞給randomForest函數的標志或參數,因此modelTRUEFALSE概率返回給我。 相反,我得到的只是預測值。 也就是說,如果我看一下model$predicted ,我會看到類似的東西:

FALSE
FALSE
TRUE
TRUE
FALSE
.
.
.

相反,我希望看到類似的東西:

   FALSE  TRUE
1  0.84   0.16
2  0.66   0.34
3  0.11   0.89
4  0.17   0.83
5  0.92   0.08
.   .      .
.   .      .
.   .      .

我可以得到上述內容,但為了做到這一點,我需要做以下事情:

tmp <- predict(model, out.data[test.rows, feature.cols], "prob")

[ test.rows捕獲模型測試期間使用的行號。 此處未顯示詳細信息,但由於測試行ID輸出到model ,因此很簡單。

一切正常。 問題是模型很大並且需要很長時間才能運行,甚至預測本身也需要一段時間。 由於預測應該完全沒必要(我只是想計算測試數據集上的ROC曲線,應該已經計算過的數據集),我希望跳過這一步。 有沒有我可以在原始的 randomForest 調用中 標記 以避免重新運行 predict 函數?

model$predictedpredict()返回的一樣。 如果你想要TRUEFALSE類的概率那么你必須運行predict() ,或者傳遞x,y,xtest,ytest類的

randomForest(x,y,xtest=x,ytest=y), 

其中x=out.data[, feature.cols], y=out.data[, response.col]

model$predicted返回基於哪個類在每個記錄的model$votes具有較大值的類。 votes ,正如@joran指出的那樣是來自隨機森林的OOB(袋外)'投票'的比例,只有當在OOB樣本中選擇記錄時才進行投票。 另一方面, predict()根據所有樹的投票返回每個類的真實概率。

使用randomForest(x,y,xtest=x,ytest=y)功能與傳遞公式或簡單的randomForest(x,y)時的功能略有不同,如上面給出的示例所示。 randomForest(x,y,xtest=x,ytest=y)將返回每個類的概率,這可能聽起來randomForest(x,y,xtest=x,ytest=y) ,但它在model$test$votes下找到,並且在model$test$predicted類下model$test$predicted ,它只根據model$test$votes哪個類具有較大值來選擇類。 此外,當使用randomForest(x,y,xtest=x,ytest=y)model$predicted randomForest(x,y,xtest=x,ytest=y)model$votes具有與上面相同的定義。

最后,請注意,如果使用randomForest(x,y,xtest=x,ytest=y) ,則為了使用predict()函數,keep.forest標志應設置為TRUE。

model=randomForest(x,y,xtest=x,ytest=y,keep.forest=TRUE). 
prob=predict(model,x,type="prob")

prob 等同於model$test$votes因為測試數據輸入都是x

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM