简体   繁体   English

绘制 tidymodels 的决策树结果

[英]Plotting decision tree results from tidymodels

I have managed to build a decision tree model using the tidymodels package but I am unsure how to pull the results and plot the tree.我已经设法使用tidymodels package 构建了一个决策树 model 但我不确定如何提取结果和 plot 这棵树。 I know I can use the rpart and rpart.plot packages to achieve the same thing but I would rather use tidymodels as that is what I am learning.我知道我可以使用rpartrpart.plot包来实现同样的事情,但我宁愿使用tidymodels ,因为这就是我正在学习的。 Below is an example using the mtcars data.下面是一个使用mtcars数据的例子。

library(tidymodels)
library(rpart)
library(rpart.plot)
library(dplyr) #contains mtcars

#data
df <- mtcars %>%
    mutate(gear = factor(gear))


#train/test
set.seed(1234)

df_split <- initial_split(df)
df_train <- training(df_split)
df_test <- testing(df_split)


df_recipe <- recipe(gear~ ., data = df) %>%
  step_normalize(all_numeric())


#building model
tree <- decision_tree() %>%
   set_engine("rpart") %>%
   set_mode("classification")

#workflow
 tree_wf <- workflow() %>%
   add_recipe(df_recipe) %>%
   add_model(tree) %>%
   fit(df_train) #results are found here 

rpart.plot(tree_wf$fit$fit) #error is here

The error I get says Error in rpart.plot(tree_wf$fit$fit): Not an rpart object which makes sense but I am unaware if there is a package or step I am missing to convert the results into a format that rpart.plot will allow me to plot. This might not be possible but any help would be much appreciated.我得到的Error in rpart.plot(tree_wf$fit$fit): Not an rpart object这是有道理的,但我不知道是否有 package 或我缺少将结果转换为rpart.plot格式的步骤将允许我拨打 plot。这可能不可能,但我们将不胜感激。

You can also use the workflows::pull_workflow_fit() function. It makes the code a little bit more elegant.您还可以使用workflows::pull_workflow_fit() function。它使代码更优雅一些。

tree_fit <- tree_wf %>% 
  pull_workflow_fit()
rpart.plot(tree_fit$fit)

The following works (note the extra $fit ):以下作品(注意额外的$fit ):

rpart.plot(tree_wf$fit$fit$fit)

Not a very elegant solution, but it does plot the tree.这不是一个非常优雅的解决方案,但它确实 plot 树。

Tested with parsnip 0.1.3 and rpart.plot 3.0.8.使用防风草 0.1.3 和 rpart.plot 3.0.8 进行测试。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM