Redefining the predict.multinom method of R's nnet::multinom to support type="link"

I would like R's nnet::multinom function to be supported by the new marginaleffects package, but marginaleffects::predictions() relies on the predict() methods supplied by the modeling packages to compute predicted values on both the response and the link scale. In the case of nnet::multinom, however, the predict() method supplied by nnet does not support predictions on the link scale: it only supports type="probs" or type="class" (see https://github.com/vincentarelbundock/marginaleffects/issues/404). So I would like to redefine the predict.multinom method used by nnet::multinom so that it also supports type="link", and do so in the original namespace of that package, so that the marginaleffects package would also see the redefined method. Is there any way to accomplish this?
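The limitation is easy to see directly. The minimal sketch below uses the built-in iris data purely as an illustrative stand-in (it is not the SARS-CoV2 data from the reproducible example further down) and assumes nnet, a recommended package that ships with most R installations, is available. It reads the accepted `type` values from the formals of nnet's unexported method and shows that `type = "link"` is rejected by `match.arg()`.

```r
library(nnet)

# Illustrative 3-category fit on a built-in dataset; trace = FALSE silences the optimizer log
fit <- multinom(Species ~ Sepal.Length, data = iris, trace = FALSE)

# The type values accepted by nnet's (unexported) predict method
allowed_types <- eval(formals(nnet:::predict.multinom)$type)
print(allowed_types)

# Asking for the link scale fails inside match.arg(); capture the error instead of stopping
link_attempt <- tryCatch(predict(fit, type = "link"), error = function(e) e)
print(conditionMessage(link_attempt))
```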

For reference, the predict.multinom method (https://github.com/cran/nnet/blob/master/R/multinom.R) currently looks like this:

predict.multinom <- function(object, newdata, type=c("class","probs"), ...)
{
    if(!inherits(object, "multinom")) stop("not a \"multinom\" fit")
    type <- match.arg(type)
    if(missing(newdata)) Y <- fitted(object)
    else {
        newdata <- as.data.frame(newdata)
        rn <- row.names(newdata)
        Terms <- delete.response(object$terms)
        m <- model.frame(Terms, newdata, na.action = na.omit,
                         xlev = object$xlevels)
        if (!is.null(cl <- attr(Terms, "dataClasses")))
            .checkMFClasses(cl, m)
        keep <- match(row.names(m), rn)
        X <- model.matrix(Terms, m, contrasts = object$contrasts)
        Y1 <- predict.nnet(object, X)
        Y <- matrix(NA, nrow(newdata), ncol(Y1),
                    dimnames = list(rn, colnames(Y1)))
        Y[keep, ] <- Y1
    }
    switch(type, class={
        if(length(object$lev) > 2L)
            Y <- factor(max.col(Y), levels=seq_along(object$lev),
                        labels=object$lev)
        if(length(object$lev) == 2L)
            Y <- factor(1 + (Y > 0.5), levels=1L:2L, labels=object$lev)
        if(length(object$lev) == 0L)
            Y <- factor(max.col(Y), levels=seq_along(object$lab),
                        labels=object$lab)
    }, probs={})
    drop(Y)
}

with predict.nnet (https://github.com/cran/nnet/blob/master/R/nnet.R) defined as:

predict.nnet <- function(object, newdata, type=c("raw","class"), ...)
{
    if(!inherits(object, "nnet")) stop("object not of class \"nnet\"")
    type <- match.arg(type)
    if(missing(newdata)) z <- fitted(object)
    else {
        if(inherits(object, "nnet.formula")) { #
            ## formula fit
            newdata <- as.data.frame(newdata)
            rn <- row.names(newdata)
            ## work hard to predict NA for rows with missing data
            Terms <- delete.response(object$terms)
            m <- model.frame(Terms, newdata, na.action = na.omit,
                             xlev = object$xlevels)
            if (!is.null(cl <- attr(Terms, "dataClasses")))
                .checkMFClasses(cl, m)
            keep <- match(row.names(m), rn)
            x <- model.matrix(Terms, m, contrasts = object$contrasts)
            xint <- match("(Intercept)", colnames(x), nomatch=0L)
            if(xint > 0L) x <- x[, -xint, drop=FALSE] # Bias term is used for intercepts
        } else {
            ## matrix ...  fit
            if(is.null(dim(newdata)))
                dim(newdata) <- c(1L, length(newdata)) # a row vector
            x <- as.matrix(newdata)     # to cope with dataframes
            if(any(is.na(x))) stop("missing values in 'x'")
            keep <- 1L:nrow(x)
            rn <- rownames(x)
        }
        ntr <- nrow(x)
        nout <- object$n[3L]
        .C(VR_set_net,
           as.integer(object$n), as.integer(object$nconn),
           as.integer(object$conn), rep(0.0, length(object$wts)),
           as.integer(object$nsunits), as.integer(0L),
           as.integer(object$softmax), as.integer(object$censored))
        z <- matrix(NA, nrow(newdata), nout,
                    dimnames = list(rn, dimnames(object$fitted.values)[[2L]]))
        z[keep, ] <- matrix(.C(VR_nntest,
                               as.integer(ntr),
                               as.double(x),
                               tclass = double(ntr*nout),
                               as.double(object$wts)
                               )$tclass, ntr, nout)
        .C(VR_unset_net)
    }
    switch(type, raw = z,
           class = {
               if(is.null(object$lev)) stop("inappropriate fit for class")
               if(ncol(z) > 1L) object$lev[max.col(z)]
               else object$lev[1L + (z > 0.5)]
           })
}
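Since multinom fits a softmax model whose coefficients, as returned by coef(), are contrasts against the first (reference) level, the link-scale values that predict.multinom does not expose can be rebuilt from the coefficients. The sketch below (again using iris only as an illustrative example) computes the baseline-category logits as the model matrix times the transposed coefficient matrix, then checks that a softmax with a zero column for the reference level reproduces `predict(type = "probs")`. Computing the link this way also avoids taking log(0) if a predicted probability underflows to zero.

```r
library(nnet)

# Illustrative 3-category fit on a built-in dataset
fit <- multinom(Species ~ Sepal.Length, data = iris, trace = FALSE)

# Design matrix for the same right-hand side, including the intercept column
X <- model.matrix(~ Sepal.Length, data = iris)

# For more than 2 levels, coef() is a (K - 1) x p matrix: one row per non-reference level
B <- coef(fit)

# Link scale: log(p_k / p_1) for k = 2..K, one column per non-reference level
eta <- X %*% t(B)

# Back to probabilities: softmax after prepending a 0 column for the reference level
expo <- exp(cbind(0, eta))
p_manual <- expo / rowSums(expo)

# nnet's own probabilities for the same rows, for comparison
p_nnet <- predict(fit, newdata = iris, type = "probs")
all.equal(unname(p_manual), unname(p_nnet))
```

For a two-level outcome, coef() returns a plain vector rather than a matrix, so this sketch would need adjusting in that case.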

I was hoping I could overwrite the predict.multinom function with the predict.mblogit function from the mclogit package (https://github.com/melff/mclogit/blob/master/pkg/R/mblogit.R), or with something close to it (some minor edits would probably be needed, since mblogit and nnet objects are structured slightly differently):

predict.mblogit <- function(object, newdata=NULL,type=c("link","response"),se.fit=FALSE,...){
  
  type <- match.arg(type)
  mt <- terms(object)
  rhs <- delete.response(mt)
  if(missing(newdata)){
    m <- object$model
    na.act <- object$na.action
  }
  else{
    m <- model.frame(rhs,data=newdata,na.action=na.exclude)
    na.act <- attr(m,"na.action")
  }
  X <- model.matrix(rhs,m,
                    contrasts.arg=object$contrasts,
                    xlev=object$xlevels
  )
  rn <- rownames(X)
  D <- object$D
  XD <- X%x%D
  rspmat <- function(x){
    y <- t(matrix(x,nrow=nrow(D)))
    colnames(y) <- rownames(D)
    y
  }
  
  eta <- c(XD %*% coef(object))
  eta <- rspmat(eta)
  rownames(eta) <- rn
  if(se.fit){
    V <- vcov(object)
    stopifnot(ncol(XD)==ncol(V))
  }
  
  if(type=="response") {
    exp.eta <- exp(eta)
    sum.exp.eta <- rowSums(exp.eta)
    p <- exp.eta/sum.exp.eta
    
    if(se.fit){
      p.long <- as.vector(t(p))
      s <- rep(1:nrow(X),each=nrow(D))
      
      wX <- p.long*(XD - rowsum(p.long*XD,s)[s,,drop=FALSE])
      se.p.long <- sqrt(rowSums(wX * (wX %*% V)))
      se.p <- rspmat(se.p.long)
      rownames(se.p) <- rownames(p)
      if(is.null(na.act))
        list(fit=p,se.fit=se.p)
      else
        list(fit=napredict(na.act,p),
             se.fit=napredict(na.act,se.p))
    }
    else {
      if(is.null(na.act)) p
      else napredict(na.act,p)
    }
  }
  else if(se.fit) {
    se.eta <- sqrt(rowSums(XD * (XD %*% V)))
    se.eta <- rspmat(se.eta)
    eta <- eta[,-1,drop=FALSE]
    se.eta <- se.eta[,-1,drop=FALSE]
    if(is.null(na.act))
        list(fit=eta,se.fit=se.eta) 
    else
      list(fit=napredict(na.act,eta),
           se.fit=napredict(na.act,se.eta))
  }
  else {
      eta <- eta[,-1,drop=FALSE]
      if(is.null(na.act)) eta
      else napredict(na.act,eta)
  }
}
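To make the meaning of the two scales in predict.mblogit explicit: with type = "link" it returns eta with its first column removed, and with type = "response" it applies exp() and divides by the row sums. Assuming, as that code suggests, that the first category is the reference (its linear predictor is 0), the two scales are related as follows for observation i with covariate vector x_i and categories 1, ..., K:

```latex
\eta_{i1} = 0, \qquad
\eta_{ik} = \log\frac{p_{ik}}{p_{i1}} = x_i^\top \beta_k, \quad k = 2, \dots, K,
\qquad
p_{ik} = \frac{\exp(\eta_{ik})}{\sum_{j=1}^{K} \exp(\eta_{ij})}
       = \frac{\exp(\eta_{ik})}{1 + \sum_{j=2}^{K} \exp(\eta_{ij})}.
```

Under that reading, a type = "link" for multinom fits would naturally return the K - 1 columns log(p_k / p_1).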

Reproducible example of what I would like to achieve:

# data=SARS-CoV2 coronavirus variants (variant) through time (collection_date_num)
# in India, count=actual count (nr of sequenced genomes)
dat = read.csv("https://www.dropbox.com/s/u27cn44p5srievq/dat.csv?dl=1")
dat$collection_date = as.Date(dat$collection_date)
dat$collection_date_num = as.numeric(dat$collection_date) # numeric version of date, to convert back to date: as.Date(dat$collection_date_num, origin="1970-01-01")
dat$variant = factor(dat$variant)

# 1. nnet::multinom multinomial fit ####
library(nnet)
library(splines)
set.seed(1)
fit_nnet = nnet::multinom(variant ~ ns(collection_date_num, df=2), 
                          weights=count, data=dat)
summary(fit_nnet)

# 2. predicted probabilities & 95% CLs at maximum date calculated using emmeans: works, but slow for large models ####
library(emmeans)
multinom_emmeans = emmeans(fit_nnet, ~ variant,
                       mode = "prob",
                       at=list(collection_date_num =
                                 max(dat$collection_date_num)))
multinom_emmeans
# variant               prob       SE df lower.CL upper.CL
# Alpha             0.00e+00 0.00e+00 33 0.00e+00 0.00e+00
# Beta              0.00e+00 0.00e+00 33 0.00e+00 0.00e+00
# Delta             7.73e-06 1.17e-06 33 5.34e-06 1.01e-05
# Omicron (BA.1)    1.82e-04 6.42e-05 33 5.14e-05 3.13e-04
# Omicron (BA.2)    1.76e-01 7.45e-03 33 1.61e-01 1.91e-01
# Omicron (BA.2.74) 9.03e-02 7.98e-03 33 7.41e-02 1.07e-01
# Omicron (BA.2.75) 1.68e-01 1.90e-02 33 1.30e-01 2.07e-01
# Omicron (BA.2.76) 2.89e-01 1.35e-02 33 2.62e-01 3.16e-01
# Omicron (BA.3)    1.34e-02 2.10e-03 33 9.10e-03 1.76e-02
# Omicron (BA.4)    1.67e-02 2.47e-03 33 1.17e-02 2.17e-02
# Omicron (BA.5)    2.03e-01 1.08e-02 33 1.81e-01 2.25e-01
# Other             4.23e-02 3.15e-03 33 3.59e-02 4.87e-02
#
# Confidence level used: 0.95 


# 3. predicted probabilities & 95% CLs at maximum date calculated using marginaleffects: does not work because of lack of a predict.multinom method supporting type="link" ####

library(marginaleffects)
multinom_preds_marginaleffects = predictions(fit_nnet,
                                         newdata = datagrid(collection_date_num =
                                                              max(dat$collection_date_num)),
                                         type="link", # not supported by predict.multinom
                                         transform_post = insight::link_inverse(fit_nnet))
# Error: The `type` argument for models of class `multinom` must be an element of: probs
# PS: desired output should match emmeans output above
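The transform_post step in the call above has to map link-scale predictions back to probabilities. For a multinomial logit, each probability depends on all K - 1 logits of its row at once, so an element-wise inverse such as plogis() cannot reproduce them. Below is a minimal sketch of the row-wise inverse, assuming the link scale is the baseline-category logits log(p_k / p_1) with the first level as reference; the helper name baseline_logit_inverse is hypothetical and not part of nnet, insight or marginaleffects.

```r
# Hypothetical helper: baseline-category logits (K - 1 columns) -> probabilities (K columns)
baseline_logit_inverse <- function(eta) {
  # Accept a single row given as a plain vector
  if (is.null(dim(eta))) eta <- matrix(eta, nrow = 1L)
  # Prepend the reference category, whose logit is fixed at 0
  full <- cbind(0, eta)
  # Subtract each row's maximum before exponentiating to avoid overflow
  full <- full - apply(full, 1L, max)
  e <- exp(full)
  # Normalize each row so the probabilities sum to 1
  e / rowSums(e)
}

# Round trip: probabilities -> logits -> probabilities
p0 <- c(0.2, 0.3, 0.5)
eta0 <- log(p0[-1] / p0[1])
baseline_logit_inverse(eta0)
```

Subtracting the row maximum leaves the result mathematically unchanged but keeps very large logits from overflowing to Inf.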

The way to redefine a method in a package is to use assignInNamespace. However, assuming this is intended to be part of another package that will eventually be made public, it's a bit rude, since you're trampling over someone else's code. In particular, if you intend to put it on CRAN, you might have trouble convincing the CRAN reviewers that it's OK.
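For completeness, the assignInNamespace() route might look like the minimal sketch below (not a recommendation, for the reasons just given). It saves nnet's original method with getFromNamespace(), defines a hypothetical replacement predict_multinom_link() that adds type = "link" as the baseline-category logits log(p_k / p_1), and installs it in nnet's namespace. assignInNamespace() also updates the registered S3 method, so ordinary predict() calls, including those made from other packages, should pick it up for the rest of the session. The iris fit is only an illustration.

```r
library(nnet)

# Keep a reference to nnet's original (unexported) method before replacing it
orig_predict_multinom <- utils::getFromNamespace("predict.multinom", "nnet")

# Hypothetical replacement adding type = "link"; nnet's default ("class") stays first
predict_multinom_link <- function(object, newdata, type = c("class", "probs", "link"), ...) {
  type <- match.arg(type)
  if (type != "link")
    return(orig_predict_multinom(object, newdata, type = type, ...))
  probs <- orig_predict_multinom(object, newdata, type = "probs", ...)
  # Binary fit: probs is P(second level), so the link is the ordinary logit
  if (length(object$lev) == 2L)
    return(qlogis(probs))
  # A single prediction comes back as a named vector; restore the matrix shape
  if (is.null(dim(probs)))
    probs <- matrix(probs, nrow = 1L, dimnames = list(NULL, names(probs)))
  # Baseline-category logits log(p_k / p_1), first level as reference
  log(probs[, -1L, drop = FALSE] / probs[, 1L])
}

# Overwrite the binding in nnet's namespace (and its registered S3 method)
utils::assignInNamespace("predict.multinom", predict_multinom_link, ns = "nnet")

fit <- multinom(Species ~ Sepal.Length, data = iris, trace = FALSE)
head(predict(fit, type = "link"))
```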

A better solution would be to create a wrapper method that calls the original method. For this you'll also need to create a wrapper multinom function, so that the correct package namespace is found. A sketch implementation is shown below.

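multinom <- function(...)
{
    nnet::multinom(...)
}


# keep nnet's default type ("class") first so existing calls behave the same
predict.multinom <- function(object, newdata, type = c("class", "probs", "link"), ...)
{
    type <- match.arg(type)
    # predict.multinom is not exported from nnet, hence `:::`
    if(type != "link")
        return(nnet:::predict.multinom(object, newdata, type = type, ...))

    probs <- nnet:::predict.multinom(object, newdata, type = "probs", ...)

    # binary fit: probs is P(second level), so the link is the ordinary logit
    if(length(object$lev) == 2L)
        return(qlogis(probs))
    # a single prediction is returned as a named vector; restore the matrix shape
    if(is.null(dim(probs)))
        probs <- matrix(probs, nrow = 1L, dimnames = list(NULL, names(probs)))
    # baseline-category logits log(p_k / p_1), with the first level as reference
    log(probs[, -1L, drop = FALSE] / probs[, 1L])
}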
