
ggparty and tidymodels: cannot plot terminal node graphs, no data attached to model(?)

I am experimenting with a tidymodels workflow for ctree using the new bonsai package, an extension that enables modeling with partykit. Here is my code:

pacman::p_load(tidymodels, bonsai, modeldata, finetune)

data(penguins)

doParallel::registerDoParallel()


split <- initial_split(penguins, strata = species)
df_train <- training(split)
df_test <- testing(split)

folds <- 
  # vfold_cv(df_train, strata = species)
  bootstraps(df_train, strata = species, times = 5) # useful when the number of records is small


tree_recipe <-
  recipe(formula = species ~ flipper_length_mm + island, data = df_train) 

tree_spec <-
  decision_tree(
    tree_depth = tune(),
    min_n = tune()
  ) %>%
  set_engine("partykit") %>%
  set_mode("classification") 

tree_workflow <- 
  workflow() %>% 
  add_recipe(tree_recipe) %>% 
  add_model(tree_spec) 

set.seed(8833)
tree_tune <-
  tune_sim_anneal(
    tree_workflow, 
    resamples = folds, 
    iter = 30,
    initial = 4,
    metrics = metric_set(roc_auc, pr_auc, accuracy))


final_workflow <- finalize_workflow(tree_workflow, select_best(tree_tune, metric = "roc_auc"))

final_fit <- last_fit(final_workflow, split = split)

I understand that to extract the fitted model from the final fit I need to:

final_model <- extract_fit_parsnip(final_fit)

And then I can plot the tree.

plot(final_model$fit)
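As a hedged aside, workflows also provides extract_fit_engine(), which should return the underlying partykit object directly, i.e. the same object as extract_fit_parsnip(...)$fit. Recent versions of tune should accept the last_fit() result as well, but the sketch below only uses a plainly fitted workflow so it can run on its own. The fixed hyperparameters (tree_depth = 3, min_n = 20) and the dropping of incomplete rows are assumptions for illustration, not the tuned values from the question.

```r
library(tidymodels)
library(bonsai)

data(penguins, package = "modeldata")
# Keep only the modeled columns and drop incomplete rows (simplifying assumption)
penguins_cc <- na.omit(penguins[, c("species", "flipper_length_mm", "island")])

# Hyperparameters fixed for illustration only (not tuned)
tree_spec <- decision_tree(tree_depth = 3, min_n = 20) %>%
  set_engine("partykit") %>%
  set_mode("classification")

tree_fit <- workflow() %>%
  add_formula(species ~ flipper_length_mm + island) %>%
  add_model(tree_spec) %>%
  fit(data = penguins_cc)

# Both routes lead to the same partykit "party" object
party_tree <- extract_fit_engine(tree_fit)
same_object <- identical(party_tree, extract_fit_parsnip(tree_fit)$fit)
same_object
```

The resulting party_tree can then be passed to plot() or ggparty() exactly like final_model$fit.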

I would like to try a different plotting library that works with partykit:

library(ggparty)

ggparty(final_model$fit) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(geom_bar(x = "", color = species),
                  xlab("species")),
    # draw individual legend for each plot
    shared_legend = FALSE
  )

The ggparty code works up to the last layer: without geom_node_plot() the tree looks good and prints, just without plots in the terminal nodes.

With that layer, ggparty does not seem to see the data inside the fitted model, namely the response variable species:

    Error in layer(data = data, mapping = mapping, stat = stat, geom = GeomBar,  : 
      object 'species' not found

How can I extract the final fit from tidymodels so that it contains the fitted values, as it would if I had built the model without a tidymodels workflow?

There are two problems in your code, only one of which is related to tidymodels.

  1. The arguments to geom_bar() need to be wrapped in aes(), which is necessary both for plain ctree() output and for the result from the tidymodels workflow.

  2. The dependent variable in the output from the tidymodels workflow is no longer called species but ..y (presumably a standardized placeholder employed in tidymodels). This can be seen by simply printing the object:

     final_model$fit
     ## Model formula:
     ## ..y ~ flipper_length_mm + island
     ##
     ## Fitted party:
     ## [1] root
     ## ...
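To check the response name programmatically rather than from the printout, one can look at the variable names stored in the party object. This is a hedged sketch: it assumes the party object's data element holds the model frame that ggparty draws its node data from, and it refits a small workflow with fixed, illustrative hyperparameters (tree_depth = 3, min_n = 20) instead of the tuned one.

```r
library(tidymodels)
library(bonsai)

data(penguins, package = "modeldata")
# Keep only the modeled columns and drop incomplete rows (simplifying assumption)
penguins_cc <- na.omit(penguins[, c("species", "flipper_length_mm", "island")])

tree_fit <- workflow() %>%
  add_formula(species ~ flipper_length_mm + island) %>%
  add_model(
    decision_tree(tree_depth = 3, min_n = 20) %>%  # illustrative values, not tuned
      set_engine("partykit") %>%
      set_mode("classification")
  ) %>%
  fit(data = penguins_cc)

party_tree <- extract_fit_engine(tree_fit)

# Names of the variables stored with the tree: the response shows up as "..y",
# so ggparty layers must refer to ..y rather than species
stored_vars <- names(party_tree$data)
stored_vars
```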

Addressing both of these (and using the fill aesthetic instead of color) makes the plot work as intended. (Bonus: autoplot(final_model$fit) also just works!)

ggparty(final_model$fit) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist =  list(
    geom_bar(aes(x = "", fill = ..y)),
    xlab("species")
  ))
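For comparison, here is a hedged sketch of the same plot without tidymodels. Fitting partykit::ctree() directly keeps the original response name, so aes(fill = species) works there, while passing species outside aes() fails as soon as the layer is built, just as in the question. Dropping incomplete rows is an assumption made for simplicity.

```r
library(partykit)
library(ggplot2)
library(ggparty)

data(penguins, package = "modeldata")
# Plain data frame with only the modeled columns, incomplete rows dropped
penguins_cc <- as.data.frame(
  na.omit(penguins[, c("species", "flipper_length_mm", "island")])
)

plain_tree <- ctree(species ~ flipper_length_mm + island, data = penguins_cc)

# Outside aes(), `species` is evaluated as an ordinary R object and is not found
bad_layer <- tryCatch(geom_bar(x = "", color = species), error = function(e) e)

# Inside aes(), `species` is resolved against each terminal node's data
p <- ggparty(plain_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_bar(aes(x = "", fill = species)),
    xlab("species")
  ))

# Evaluate `p` (or call print(p)) to draw the tree
```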

[Figure: ggparty visualization]

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP Registration No. 18138465 © 2020-2024 STACKOOM.COM