
partykit: Change the terminal node boxplots to violins

The partykit package offers a plotting function for decision trees, plot.constparty(), which can display the response distribution in each terminal node as a boxplot (node_boxplot()). A minimal example using the iris dataset:

library("partykit")
ct <- ctree(Petal.Length ~ Sepal.Length + Sepal.Width, data = iris, stump = TRUE)
plot(ct, terminal_panel = node_boxplot)

I would love to display violin plots instead of the boxplots. Since you can write your own panel functions, that should be possible. However, it seems that the violin plot needs to be set up using grid functions, and I have no clue how to do that. I imagine that this is quite cumbersome work, but I believe that many users would benefit from such a panel function. Any suggestions on how to implement it? (A first lead points here: "partykit: Change terminal node boxplots to bar graphs that shows mean and standard deviation".)

Add-on: Assume we had a strategy to plot terminal nodes with violins. How could we apply this strategy to multivariate responses, to display violins instead of boxplots? See the following screenshot produced with the function node_mvar():

[Image: decision tree with a multivariate response, boxplots produced by node_mvar()]

There are two natural strategies for this:

  1. Write a node_violinplot() panel-generating function similar to node_boxplot().
  2. Use ggplot2 via the ggparty package and leverage the existing geom_violin().

For the first strategy, I would recommend copying the code of node_boxplot() (including the line that sets its class) and renaming it to, say, node_violinplot(). Most of its code is responsible for setting up the right viewport, axis ranges, etc., all of which can be preserved. Then one would "only" replace the grid.lines() and grid.rect() calls that draw the boxes with calls that draw the violin. I'm not sure what would be the best way to compute the coordinates for the violin elements, though.
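One possible way to get those coordinates is a kernel density estimate mirrored around the centre of the panel. The following is only a sketch under that assumption: violin_coords() is a hypothetical helper name (it is not part of partykit), and the default bandwidth of stats::density() is used without any tuning.

## Hypothetical helper (not part of partykit): turn a numeric sample into
## the x/y coordinates of a closed violin outline centred at x = 0.5.
violin_coords <- function(y, width = 0.8, n = 512) {
  y <- y[!is.na(y)]
  stopifnot(is.numeric(y), length(y) >= 2)
  ## kernel density estimate, evaluated only over the observed range
  kde <- stats::density(y, from = min(y), to = max(y), n = n)
  ## scale the density so that the widest point spans `width` in total
  half <- kde$y * (width / 2) / max(kde$y)
  ## left edge going up the response axis, then right edge coming back down
  list(x = c(0.5 - half, rev(0.5 + half)),
       y = c(kde$x, rev(kde$x)))
}

vc <- violin_coords(iris$Petal.Length)
str(vc)

The x values are fractions of the panel width and the y values are on the scale of the response, so the two vectors could be handed to grid.polygon() with different units for each axis.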

For the second strategy, all building blocks are essentially available and just have to be customized to obtain the kind of violin plot that you want. For example:

[Image: ggparty with geom_violin and geom_boxplot as geom_node_plot]

This plot can be replicated as follows:

## example tree
library("partykit")
ct <- ctree(dist ~ speed, data = cars)

## visualization with ggparty + geom_violin
library("ggparty")
ggparty(ct) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_violin(aes(x = "", y = dist)),
    geom_boxplot(aes(x = "", y = dist), coef = Inf, width = 0.1, fill = "lightgray"),
    xlab(""),
    theme_minimal()
  ))

Here is a version of a node_violinplot() panel-generating function:

node_violinplot <- function (obj, col = "black", fill = "lightgray", bg = "white",
                             width = 0.8, yscale = NULL, ylines = 3, cex = 0.5, id = TRUE,
                             mainlab = NULL, gp = gpar(),
                             col.box = "black", fill.box = "black", fill.median = "white")
{
  y <- obj$fitted[["(response)"]]
  stopifnot(is.numeric(y))
  if (is.null(yscale))
    yscale <- range(y) + c(-0.1, 0.1) * diff(range(y))
  rval <- function(node) {
    nid <- id_node(node)
    dat <- data_party(obj, nid)
    yn <- dat[["(response)"]]
    wn <- dat[["(weights)"]]
    if (is.null(wn))
      wn <- rep(1, length(yn))

    ## compute kernel density estimate
    kde <- stats::density(rep.int(yn, wn), from = yscale[1], to = yscale[2], na.rm = TRUE)
    ## limit kde to range(yn)
    idx <- which(kde$x < range(yn)[2] & kde$x > range(yn)[1])
    kde$y <- kde$y[idx]
    kde$x <- kde$x[idx]

    ## construct polygon coordinates
    width.scalingfactor <- width / 2 / max(kde$y, na.rm = TRUE)
    polX <- c((0.5 - (kde$y * width.scalingfactor)), rev(0.5 + (kde$y * width.scalingfactor)))
    polY <- c(kde$x, rev(kde$x))

    ## compute boxplot characteristics
    x <- boxplot(rep.int(yn, wn), plot = FALSE)

    top_vp <- viewport(layout = grid.layout(nrow = 2, ncol = 3,
                                            widths = unit(c(ylines, 1, 1), c("lines", "null", "lines")),
                                            heights = unit(c(1, 1), c("lines", "null"))),
                       width = unit(1, "npc"), height = unit(1, "npc") - unit(2, "lines"),
                       name = paste("node_boxplot", nid, sep = ""), gp = gp)
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = bg, col = 0))
    top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
    pushViewport(top)
    if (is.null(mainlab)) {
      mainlab <- if (id) {
        function(id, nobs) sprintf("Node %s (n = %s)",
                                   id, nobs)
      }
      else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(names(obj)[nid], sum(wn))
    }
    grid.text(mainlab)
    popViewport()
    plot <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                     xscale = c(0, 1), yscale = yscale,
                     name = paste0("node_boxplot", nid, "plot"), clip = FALSE)
    pushViewport(plot)
    grid.yaxis()
    grid.rect(gp = gpar(fill = "transparent"))
    grid.clip()
    ## draw violin
    grid.polygon(unit(polX,"npc"), unit(polY, "native"),
                 gp = gpar(col = col, fill = fill))
    ## draw boxplot
    box.width <- max(polX-0.5, na.rm = TRUE) * 0.08
    grid.rect(unit(0.5, "npc"), unit(x$stats[2], "native"),
              width = unit(box.width, "npc"), height = unit(diff(x$stats[c(2, 4)]), "native"),
              just = c("center", "bottom"),
              gp = gpar(col = col.box, fill = fill.box))
    grid.lines(unit(0.5, "npc"), unit(x$stats[1:2], "native"),
               gp = gpar(col = col))
    grid.lines(unit(0.5, "npc"), unit(x$stats[4:5], "native"),
               gp = gpar(col = col))
    grid.points(unit(0.5, "npc"), unit(x$stats[3], "native"),
                size = unit(0.5, "char"),
                gp = gpar(col = fill.median, fill = fill.median), pch = 19)
    upViewport(2)
  }
  return(rval)
}
class(node_violinplot) <- "grapcon_generator"
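The drawing step in node_violinplot() mixes two coordinate systems: the horizontal positions are given in "npc" (fractions of the panel width), while the vertical positions are given in "native" units (the response scale set by yscale). The following stand-alone sketch uses only the grid package and a toy diamond-shaped outline (the numbers are made up for illustration) to show that mapping in isolation:

library(grid)

## toy outline: a diamond-shaped "violin" on a response scale from 0 to 10
polX <- c(0.5, 0.2, 0.5, 0.8)   # horizontal position, fraction of panel width
polY <- c(2, 5, 8, 5)           # vertical position, in data units

grDevices::pdf(NULL)            # off-screen device; drop this line to draw on screen
grid.newpage()
pushViewport(viewport(xscale = c(0, 1), yscale = c(0, 10)))
grid.polygon(unit(polX, "npc"), unit(polY, "native"),
             gp = gpar(col = "black", fill = "lightgray"))
## convert a data value to a relative height to inspect the mapping
rel_height <- convertY(unit(5, "native"), "npc", valueOnly = TRUE)
popViewport()
invisible(grDevices::dev.off())

With node_violinplot() defined as above, it should be usable in the same way as node_boxplot in the question, i.e. plot(ct, terminal_panel = node_violinplot).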

And a version of node_mvar_violin() that plots the terminal violins for a multivariate response:

.nobs_party <- function(party, id = 1L) {
  dat <- data_party(party, id = id)
  if("(weights)" %in% names(dat)) sum(dat[["(weights)"]]) else NROW(dat)
}

#' @export
node_mvar_violin <- function(obj, which = NULL, id = TRUE, pop = TRUE, ylines = NULL, mainlab = NULL, varlab = TRUE, bg = "white", terminal_panel_mvar = node_violinplot, ...)
{
  ## obtain dependent variables
  y <- obj$fitted[["(response)"]]

  ## fitted node ids
  fitted <- obj$fitted[["(fitted)"]]

  ## number of panels needed
  if(is.null(which)) which <- 1L:NCOL(y)
  k <- length(which)

  rval <- function(node) {

    tid <- id_node(node)
    nobs <- .nobs_party(obj, id = tid)

    ## set up top viewport
    top_vp <- viewport(layout = grid.layout(nrow = k, ncol = 2,
                                            widths = unit(c(ylines, 1), c("lines", "null")), heights = unit(k, "null")),
                       width = unit(1, "npc"), height = unit(1, "npc") - unit(2, "lines"),
                       name = paste("node_mvar", tid, sep = ""))
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = bg, col = 0))

    ## main title
    if (is.null(mainlab)) {
      mainlab <- if(id) {
        function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
      } else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(tid, nobs)
    }

    for(i in 1L:k) {
      tmp <- obj
      tmp$fitted[["(response)"]] <- y[,which[i]]
      if(varlab) {
        nm <- names(y)[which[i]]
        if(i == 1L) nm <- paste(mainlab, nm, sep = ": ")
      } else {
        nm <- if(i == 1L) mainlab else ""
      }
      pfun <- switch(sapply(y, class)[which[i]],
                     "Surv" = node_surv(tmp, id = id, mainlab = nm, ...),
                     "factor" = node_barplot(tmp, id = id, mainlab = nm,  ...),
                     "ordered" = node_barplot(tmp, id = id, mainlab = nm, ...),
                     do.call("terminal_panel_mvar", list(tmp, id = id, mainlab = nm, ...)))
      ## select panel
      plot_vpi <- viewport(layout.pos.col = 2L, layout.pos.row = i)
      pushViewport(plot_vpi)

      ## call panel function
      pfun(node)

      if(pop) popViewport() else upViewport()
    }
    if(pop) popViewport() else upViewport()
  }

  return(rval)
}
class(node_mvar_violin) <- "grapcon_generator"
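The switch() call inside node_mvar_violin() picks a panel function per response column based on its class, and falls through to terminal_panel_mvar (the violin panel by default) for anything that is not a factor or a survival object. A minimal base-R sketch of that dispatch, where panel_for() is a hypothetical stand-in that returns a label instead of a panel function:

## toy multivariate response: one numeric and one factor column
y <- data.frame(len = c(1.4, 4.7, 6.0), sp = factor(c("a", "b", "b")))

## hypothetical stand-in for the switch() inside node_mvar_violin()
panel_for <- function(cls) {
  switch(cls,
         "Surv" = "node_surv",
         "factor" = "node_barplot",
         "ordered" = "node_barplot",
         "terminal_panel_mvar")   # unnamed last argument acts as the default
}

cls <- sapply(y, class)
chosen <- vapply(cls, panel_for, character(1))
print(chosen)

So in a mixed response, numeric columns would be drawn as violins while factor columns would still be drawn as bar plots.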

All in all, the result will look like this:

[Image: resulting tree plot with violins in the terminal nodes]
