
Plotting cv.glmnet in R

Using R, I am trying to modify the standard plot I get from performing a ridge regression with cv.glmnet.

I perform a ridge regression:

library(glmnet)

# xTrain: numeric predictor matrix, yTrain: numeric response
lam = 10^seq(-2, 3, length.out = 100)
cvfit = cv.glmnet(xTrain, yTrain, alpha = 0, lambda = lam)

I can plot the coefficients against log lambda by doing the following:

plot(cvfit$glmnet.fit, xvar = "lambda")
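glmnet's own plot method can already label the curves, though only with numbers: with label = TRUE, plot.glmnet writes each predictor's column index in xTrain next to its curve, while the x axis stays on the log(lambda) scale. A minimal sketch, assuming mtcars as stand-in data (as the answer below does) and adding set.seed() only to make the random cross-validation folds reproducible:

```r
library(glmnet)

# stand-in data: predict mpg from the other mtcars columns
xTrain = as.matrix(mtcars[, -1])
yTrain = mtcars[, 1]

set.seed(1)  # cv.glmnet assigns folds at random
cvfit = cv.glmnet(xTrain, yTrain, alpha = 0,
                  lambda = 10^seq(-2, 3, length.out = 100))

# xvar = "lambda" puts log(lambda) on the x axis;
# label = TRUE writes each variable's index number beside its curve
plot(cvfit$glmnet.fit, xvar = "lambda", label = TRUE)
```

The numbers refer to columns of xTrain, so colnames(xTrain)[k] gives the predictor name for curve k.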


How can I plot the coefficients against the actual lambda values (not log lambda) and label each predictor on the plot?

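You can do it like this. The values are stored in the glmnet.fit component: the coefficients in cvfit$glmnet.fit$beta (one row per predictor, one column per lambda) and the lambda values in cvfit$glmnet.fit$lambda, which cv.glmnet also exposes as cvfit$lambda: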

library(glmnet)

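# example data: predict mpg from the other mtcars columns
xTrain = as.matrix(mtcars[, -1])
yTrain = mtcars[, 1]

lam = 10^seq(-2, 3, length.out = 30)
cvfit = cv.glmnet(xTrain, yTrain, alpha = 0, lambda = lam)

# coefficient matrix: rows = predictors, columns = lambda values (s0, s1, ...)
betas = as.matrix(cvfit$glmnet.fit$beta)
lambdas = cvfit$lambda
# name each lambda after its column so it can be looked up by column name
names(lambdas) = colnames(betas)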
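One detail worth checking before plotting: glmnet fits the path from the largest lambda downwards, so the increasing lam sequence comes back in decreasing order, and column s0 of beta belongs to the largest lambda. A quick self-contained check on the same mtcars example (set.seed() is added here only to make the random folds reproducible):

```r
library(glmnet)

xTrain = as.matrix(mtcars[, -1])
yTrain = mtcars[, 1]
lam = 10^seq(-2, 3, length.out = 30)  # supplied in increasing order

set.seed(1)  # cv.glmnet assigns cross-validation folds at random
cvfit = cv.glmnet(xTrain, yTrain, alpha = 0, lambda = lam)

betas = as.matrix(cvfit$glmnet.fit$beta)
lambdas = cvfit$lambda

dim(betas)            # 10 30: predictors x lambda values
head(lambdas, 3)      # largest first, since glmnet sorts lambda decreasingly
colnames(betas)[1:3]  # "s0" "s1" "s2"; s0 pairs with the largest lambda
```

Because cvfit$lambda and the beta columns share this order, the names(lambdas) = colnames(betas) line in the answer lines them up correctly.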

With ggplot2, you can pivot the coefficients to long format, plot them on a log10 x scale (the axis is log-spaced, but its tick labels show the actual lambda values), and add the labels with ggrepel:

library(ggplot2)
library(tidyr)
library(dplyr)
library(ggrepel)

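as.data.frame(betas) %>%
  tibble::rownames_to_column("variable") %>%
  # one row per (variable, lambda column) pair, in columns "name" and "value"
  pivot_longer(-variable) %>%
  mutate(lambda = lambdas[name]) %>%
  ggplot(aes(x = lambda, y = value, col = variable)) +
  geom_line() +
  # label each line at its smallest-lambda end
  geom_label_repel(data = ~subset(.x, lambda == min(lambda)),
                   aes(label = variable), nudge_x = -0.5) +
  # log-spaced axis whose tick labels show actual lambda values
  scale_x_log10()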
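The pivot_longer() step turns the predictors x lambdas coefficient matrix into one row per (variable, lambda) pair. If you would rather not depend on tidyr, the same long table can be built in base R; a sketch under the same mtcars assumptions, with the column names variable, lambda and value chosen here to match the ggplot code:

```r
library(glmnet)

xTrain = as.matrix(mtcars[, -1])
yTrain = mtcars[, 1]
set.seed(1)  # reproducible cross-validation folds
cvfit = cv.glmnet(xTrain, yTrain, alpha = 0,
                  lambda = 10^seq(-2, 3, length.out = 30))

betas = as.matrix(cvfit$glmnet.fit$beta)
lambdas = cvfit$lambda

# as.vector() reads a matrix column by column, so within each lambda column
# the variable names cycle, and each lambda repeats nrow(betas) times
long = data.frame(
  variable = rep(rownames(betas), times = ncol(betas)),
  lambda   = rep(lambdas, each = nrow(betas)),
  value    = as.vector(betas)
)
head(long)
```

The resulting data frame can be passed straight to ggplot(aes(x = lambda, y = value, col = variable)).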


In base R, something like the following works; I think the downside is that the labels are hard to read where the lines overlap:

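# one colour per predictor (Set3 provides at most 12)
pal = RColorBrewer::brewer.pal(nrow(betas), "Set3")
# empty plot on a log10 lambda scale, padded on both sides for the labels
plot(NULL, xlim = range(log10(lambdas)) + c(-0.3, 0.3),
     ylim = range(betas), xlab = "lambda", ylab = "coef", xaxt = "n")
for (i in seq_len(nrow(betas))) {
  lines(log10(lambdas), betas[i, ], col = pal[i])
}

# ticks at powers of ten (10^-2 to 10^3), labelled with actual lambda values
axis(side = 1, at = -2:3, labels = 10^(-2:3))
# variable names just left of the smallest lambda (last column of betas)
text(x = log10(min(lambdas)) - 0.1, y = betas[, ncol(betas)],
     labels = rownames(betas), cex = 0.5)

legend("topright", fill = pal, legend = rownames(betas))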
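Another base R option, swapped in here rather than taken from the answer, is matplot(), which draws one line per column of a matrix. With log = "x" the x axis is in actual lambda units, spaced logarithmically, so no manual axis() call is needed. A sketch on the same mtcars stand-in data:

```r
library(glmnet)

xTrain = as.matrix(mtcars[, -1])
yTrain = mtcars[, 1]
set.seed(1)  # reproducible cross-validation folds
cvfit = cv.glmnet(xTrain, yTrain, alpha = 0,
                  lambda = 10^seq(-2, 3, length.out = 30))

betas = as.matrix(cvfit$glmnet.fit$beta)
lambdas = cvfit$lambda
cols = seq_len(nrow(betas))  # default palette colours, recycled if needed

# matplot() draws one line per column, so pass t(betas): lambdas x predictors.
# The lower x limit is pushed down a decade to leave room for the labels.
matplot(lambdas, t(betas), type = "l", lty = 1, col = cols, log = "x",
        xlim = c(min(lambdas) / 10, max(lambdas)),
        xlab = "lambda", ylab = "coefficient")

# name each curve at its smallest-lambda end (the last column of betas)
text(min(lambdas), betas[, ncol(betas)], labels = rownames(betas),
     pos = 2, cex = 0.6, col = cols)
```

Dropping log = "x" and the xlim argument gives a linear lambda axis, but with lambda running from 0.01 to 1000, everything below 10 is squeezed into roughly the first 1% of the axis.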

