Neural network with R package nnet: rubbish prediction due to overfitting?

I'm trying to figure out whether I have an R problem or a general neural network problem.

Say I have this data:

set.seed(123)
n = 1e3
x = rnorm(n)
y = 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n = length(x))
df = data.frame(y,x)
df$train = sample(c(TRUE, FALSE), length(y), replace=TRUE, prob=c(0.7,0.3))
df_train = subset(df, train = TRUE)
df_test = subset(df, train = FALSE)

Then I train the neural network, and it looks good:

library(nnet)
nn = nnet(y~x, data = df_train, size = 60, linout=TRUE) 
yhat_nn = predict(nn, newdata = df_test)
plot(df_test$x,df_test$y)
points(df_test$x, yhat_nn, col = 'blue')

OK, so I thought, let's generate new data and then predict with the trained network. But the predictions are way off:

x2 = rnorm(n)
y2 = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n = length(x2))
df2 = data.frame(y2,x2)
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')

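Is this because I overfitted to my training set? I thought that splitting the original data into training and test sets would guard against overfitting.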

The fatal issue is that your new data frame df2 does not have the correct variable names. As a result, predict.nnet cannot find the right values.

names(df)
#[1] "y"     "x"     "train"

names(df2)
#[1] "y2"     "x2"
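
To see why there is no error message, here is a small sketch of what I believe happens: when newdata has no column called x, the lookup falls back to the environment the formula was created in, where the original x still exists. The data, the network size and the names d, bad and good below are made up for illustration, and the fallback only goes unnoticed because both vectors have the same length.

```r
library(nnet)

set.seed(1)
x <- rnorm(200)                       # lives in the workspace, like x in the question
y <- sin(x) + rnorm(200, sd = 0.1)
d <- data.frame(y, x)
nn <- nnet(y ~ x, data = d, size = 3, linout = TRUE, trace = FALSE)

x2 <- rnorm(200)
bad  <- predict(nn, newdata = data.frame(x2 = x2))  # no column named x
good <- predict(nn, newdata = data.frame(x = x2))   # column named as in the formula

# The "bad" predictions should simply be the predictions for the old x
same_as_training <- isTRUE(all.equal(as.vector(bad),
                                     as.vector(predict(nn, newdata = d))))
same_as_training
```

So the blue points in the second plot would be predictions for the old x drawn against the new x2, which is why they look like noise.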

Be careful when constructing the data frame you pass to predict.

## the right way
df2 <- data.frame(y = y2, x = x2)

## and it solves the mystery
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')
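
If you want to catch this kind of mistake early, a small guard like the one below may help. check_newdata is a hypothetical helper of my own, not something nnet provides; it assumes the model was fitted through the formula interface so that terms() can recover the predictor names. The toy data and network size are for illustration only.

```r
library(nnet)

# Hypothetical helper, not part of nnet: fail fast when newdata lacks
# a predictor named in the model formula.
check_newdata <- function(model, newdata) {
  needed <- all.vars(delete.response(terms(model)))
  absent <- setdiff(needed, names(newdata))
  if (length(absent) > 0) {
    stop("newdata is missing predictor(s): ", paste(absent, collapse = ", "))
  }
  invisible(TRUE)
}

set.seed(1)
d <- data.frame(x = rnorm(100))
d$y <- sin(d$x) + rnorm(100, sd = 0.1)
fit <- nnet(y ~ x, data = d, size = 3, linout = TRUE, trace = FALSE)

check_newdata(fit, data.frame(x = 0))   # passes silently

# A wrongly named column is reported instead of being silently ignored
msg <- tryCatch(check_newdata(fit, data.frame(x2 = 0)),
                error = function(e) conditionMessage(e))
msg
```

Calling check_newdata(nn, df2) before predict would have flagged the y2/x2 data frame from the question.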

Figure: prediction on df2

Another minor issue is your use of subset. It should be

## not train = TRUE or train = FALSE
df_train <- subset(df, train == TRUE) ## or simply subset(df, train)
df_test <- subset(df, train == FALSE) ## or simply subset(df, !train)

The difference has an interesting effect:

nrow(subset(df, train == TRUE))
#[1] 718

nrow(subset(df, train = TRUE))  ## oops!!
#[1] 1000
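
The reason, as far as I can tell, is that train = TRUE is parsed as a named argument rather than a condition. It matches none of subset's formal arguments, gets absorbed by `...`, and with no filter given every row is kept. A toy example with a made-up data frame d:

```r
d <- data.frame(v = 1:6,
                train = c(TRUE, FALSE, TRUE, TRUE, FALSE, TRUE))

n_wrong <- nrow(subset(d, train = TRUE))   # named argument, no filtering happens
n_right <- nrow(subset(d, train == TRUE))  # logical condition, rows are filtered
n_short <- nrow(subset(d, train))          # same filter, shorter
n_index <- nrow(d[d$train, ])              # plain indexing, no subset() at all

c(n_wrong, n_right, n_short, n_index)      # 6 4 4 4
```

In the question this means df_train and df_test were both the full data set, so the first plot compared the network against data it had been trained on and could not have revealed overfitting either way.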

A complete R session

set.seed(123)
n = 1e3
x = rnorm(n)
y = 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n = length(x))
df = data.frame(y,x)
df$train = sample(c(TRUE, FALSE), length(y), replace=TRUE, prob=c(0.7,0.3))
df_train = subset(df, train == TRUE)  ## fixed
df_test = subset(df, train == FALSE)  ## fixed
library(nnet)
nn = nnet(y~x, data = df_train, size = 60, linout=TRUE) 
yhat_nn = predict(nn, newdata = df_test)
plot(df_test$x,df_test$y)
points(df_test$x, yhat_nn, col = 'blue')
x2 = rnorm(n)
y2 = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n = length(x2))
df2 = data.frame(y = y2, x = x2)  ## fixed
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')
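
To answer the overfitting question with numbers instead of plots, one option is to compare the error on the training rows, the held-out rows and the fresh data. This is only a sketch: rmse is a helper I define here, not an nnet function, and the values you get will vary with the random starting weights of the network.

```r
library(nnet)

set.seed(123)
n <- 1e3
x <- rnorm(n)
y <- 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n)
df <- data.frame(y, x)
df$train <- sample(c(TRUE, FALSE), n, replace = TRUE, prob = c(0.7, 0.3))
df_train <- subset(df, train)
df_test  <- subset(df, !train)

nn <- nnet(y ~ x, data = df_train, size = 60, linout = TRUE, trace = FALSE)

# Hypothetical helper: root mean squared error of a model on a data frame
# that has columns y and x
rmse <- function(model, data) {
  sqrt(mean((data$y - predict(model, newdata = data))^2))
}

x2 <- rnorm(n)
df2 <- data.frame(y = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n), x = x2)

res <- c(train = rmse(nn, df_train),
         test  = rmse(nn, df_test),
         new   = rmse(nn, df2))
res
```

A training error far below the test and new-data errors would point to overfitting; three similar values would suggest the network generalises about as well as it fits.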
