简体   繁体   中英

Predict method of caret package gives error :Error in models[[1]]$trainingData$.outcome

I am new to neural network and caret package and struggling with an issue. I am training the model using train() method of caret package with method='nnet', and getting model fit without any error. But model throws error when I use predict() method to see predicted values on training data.

My training data is in data frame nnTrainingDataScaled and it looks like :

    PAYMENT_DELAY   AMT_TO_PAY NUMBER_OF_CREDIT_DAYS AVG_BASE_PRICE DELIVERY_DURATION
155  0.2258064516 0.2287152972          0.2333333333   0.7468513854     0.05882352941
158  0.2258064516 0.1564039392          0.2333333333   0.7732997481     0.05882352941
162  0.2258064516 0.4230656560          0.2333333333   0.8060453401     0.05882352941
164 -0.2258064516 0.3951407685          0.2333333333   0.7204030227     0.05882352941
166 -0.2258064516 1.0000000000          0.2333333333   0.6700251889     0.05882352941
168 -0.2258064516 0.2438498559          0.2333333333   0.7657430730     0.05882352941

I train the model as :

myGrid <- expand.grid(.decay = c(0.5, 0.1), .size = c(5, 6, 7))
neuralNetFit <- train(PAYMENT_DELAY ~ AMT_TO_PAY + NUMBER_OF_CREDIT_DAYS + AVG_BASE_PRICE + DELIVERY_DURATION, data = nnTrainingDataScaled, method = 'nnet', tuneGrid = myGrid, maxit = 1000,  trace = F, linout = TRUE) 

and the result is :

Resampling results across tuning parameters:

  decay  size  RMSE          Rsquared       RMSE SD        Rsquared SD  
  0.1    5     0.2810123029  0.08531037005  0.02361680282  0.04357379768
  0.1    6     0.2809714015  0.08556298113  0.02361410081  0.04368294247
  0.1    7     0.2809433123  0.08574371206  0.02359794076  0.04369737967
  0.5    5     0.2907021463  0.02093119653  0.02565310137  0.02134393442
  0.5    6     0.2907006528  0.02095733170  0.02565351995  0.02136547343
  0.5    7     0.2906981746  0.02097019598  0.02565722475  0.02137058017

RMSE was used to select the optimal model using  the smallest value.
The final values used for the model were size = 7 and decay = 0.1. 

But when I go to predict with this model on training data, I get below error.

fitValueOnTrainingData <- predict(neuralNetFit, nnTrainingDataScaled)

 Error in models[[1]]$trainingData$.outcome : 
  $ operator is invalid for atomic vectors 

I found below code snippet on stack exchange and predict() method works fine with this code. I am pondering that where am I going wrong in my code. Any help is appreciated. Thanks in advance!

library(caret)
y <- sin(seq(0, 20, 0.1))
te <- data.frame(y, x1=Lag(y), x2=Lag(y,2))
names(te) <- c("y", "x1", "x2")
model <- train(y ~ x1 + x2, te, method='nnet', linout=TRUE, trace = FALSE, tuneGrid=expand.grid(.size=c(1,5,10),.decay=c(0,0.001,0.1))) 
ps <- predict(model, te)

Here is dput of the training data :

structure(c(0, -0.291666666666667, -0.291666666666667, -0.291666666666667, 
-0.291666666666667, -0.291666666666667, -0.291666666666667, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.166666666666667, 
-0.166666666666667, -0.166666666666667, -0.166666666666667, -0.166666666666667, 
-0.166666666666667, -0.166666666666667, -0.166666666666667, 0, 
-0.333333333333333, -0.0416666666666667, 0, 0, -0.166666666666667, 
-0.0416666666666667, -0.166666666666667, 0, 0, 0, 0, 0, -0.125, 
-0.125, -0.125, 0, 0, -0.166666666666667, -0.166666666666667, 
0, 0, -0.166666666666667, -0.458333333333333, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.291666666666667, 0, 0, 0, 0, 
0, 0, 0, 0, 0, -0.166666666666667, 0, 0, 0, 0, 0, 0, 0, -0.166666666666667, 
-0.166666666666667, -0.166666666666667, -0.166666666666667, -0.166666666666667, 
-0.166666666666667, -0.166666666666667, -0.166666666666667, -0.166666666666667, 
0, -0.166666666666667, 0, 0.291666666666667, 0.291666666666667, 
0.291666666666667, 0.291666666666667, 0.291666666666667, 0.125, 
0.0416666666666667, 0.0416666666666667, 0.0416666666666667, 0.0416666666666667, 
0.291666666666667, 0.291666666666667, 0.291666666666667, 0.125, 
0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 
0.0833333333333333, 0.0833333333333333, 0.291666666666667, 0.166666666666667, 
0.125, 0.125, 0.125, 0.125, 0.125, 0.291666666666667, 0.0416666666666667, 
0.291666666666667, 0.291666666666667, 0.291666666666667, 0.291666666666667, 
0.291666666666667, 0.291666666666667, 0.708333333333333, 0.708333333333333, 
0.416666666666667, 0.416666666666667, 0.416666666666667, 0.708333333333333, 
0.416666666666667, 0.416666666666667, 0.416666666666667, 0.416666666666667, 
0.416666666666667, 0.416666666666667, 0.416666666666667, 0.416666666666667, 
0.708333333333333, 0.416666666666667, 1, 1, 1, 1, 1, 1, 0.708333333333333, 
1, 1, 0.416666666666667, 0.333333333333333, 0.875, 0.583333333333333, 
0.333333333333333, 0.583333333333333, 0.165416923477883, 0.453265327838473, 
0.1114256530887, 0.329838130541146, 0.165433887154539, 0.166473811938624, 
0.701185863808407, 1, 0.15895941723088, 0.510158415037425, 0.716269656660325, 
0.307855775849173, 0.363710480416463, 0.858129173771277, 0.149367743209003, 
0.319004509767736, 0.0928689329795978, 0.155315825105522, 0.0994436428107681, 
0.369230357990123, 0.527611982113088, 0.694437918855024, 0.412800276353714, 
0.0279489424175707, 0.350071170334333, 0.372184607982284, 0.00974537521339534, 
0.050521427436845, 0.0541824972896671, 0.0158672062829346, 0.00439102199983242, 
0.022531360954654, 0.0204299212525568, 0.111857969818021, 0.0905012149590841, 
0.146286522667327, 0.0458759502871745, 0.0944080010980125, 0.17924437615565, 
0.0153207702742924, 0.0570287966121996, 0.00761052220879407, 
0.0452817075534112, 0.0835887737472196, 0.123467807311139, 0.0933248960460754, 
0.322160781727344, 0.082304160778643, 0.0411019604355655, 0.0159499684629829, 
0.0178221415048221, 0.017853498604095, 0.0342671408956718, 0.0325959617196644, 
0.191417613334067, 0.183628201444174, 0.0347205337081106, 0.115618765527547, 
0.333991147016989, 0.0484893845937945, 0.0572004895819893, 0.0473744083917766, 
0.0347693685348472, 0.0340018906788709, 0.0526706738640634, 0.167987177516651, 
0.036454427082664, 0.0256927734223395, 0.0086936272607312, 0.0312100807419604, 
0.0319667635309739, 0.0333978809797603, 0.0407981564081831, 0.0169898932470688, 
0.0437431534858042, 0.0428451067246585, 0.00719259890209028, 
0.0376804382591567, 0.0604744791765729, 0.0179722443406861, 0.00571984333787583, 
0.0130512359580596, 0.0580250270776263, 0.0112366366066889, 0.0530037787874878, 
0.00749383267543397, 0.0537116267497647, 0.00373971962640843, 
0.0557256778145438, 0.0832474440108711, 0.0129566506094329, 0.0297249879583598, 
0.00721418903601594, 0.0283915401630466, 0.0226434240307444, 
0.00776370934950469, 0.0153665207961825, 0.0156672405187184, 
0.00956699958310479, 0.0067679929348857, 0.0761520007114463, 
0.0229348908387407, 0.14191503459819, 0.160810514189601, 0.198521795497223, 
0.157238375126521, 0.103231169162298, 0.339058145829017, 0.0133951359484469, 
0.0101910572637178, 0.00372069974652155, 0.20366024737153, 0.00231220053327631, 
0.0720149198106442, 0.312968011132284, 0.160259965774497, 0.0231852335821168, 
0.496270818415151, 0.028445001447053, 0.0354823570052017, 0.0614773923025004, 
0.0979518673666668, 0.0185628887187952, 0.00765678678149191, 
0.179498317254681, 0.176430976084814, 0.0769683133941593, 0.0248502441484311, 
0.0488415093971058, 0.0052001379712368, 0.0490795149210958, 0.0436537086452551, 
0.279442645552201, 0.137631963267985, 0.00221967138788064, 0.0479696792271555, 
0.00367803352947799, 0.319617258330579, 0.31295978631936, 0.141324904270888, 
0.328198308464412, 0.169229638318992, 0.16465767043483, 0.0247664538667673, 
0.007875772425595, 0.0492100838262652, 0.0617616623991882, 0.0840735236589314, 
0.0137441764469117, 0.0275356455681367, 0.0651328075964372, 0.45743325178774, 
0.052408507952109, 0.0266401690610297, 0.0352762226312924, 0.0395037764742592, 
0.0613617308707558, 0.00741980935911744, 0.0742895946349545, 
0.0298180311545633, 0.0253072353165242, 0.0106161772817302, 0.0219751579806645, 
0.0286305737886521, 0.0187006543352732, 0.00480637505249744, 
0.0234263234109533, 0.0233014090646691, 0.0451686163757053, 0.25109325755539, 
0.050316835215359, 0.0432640581329777, 0.173760996189341, 0.318880623523068, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
0, 0.233333333333333, 0, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 0, 0.233333333333333, 
1, 1, 1, 1, 1, 1, 1, 1, 0.233333333333333, 1, 1, 1, 0.233333333333333, 
1, 0.233333333333333, 1, 1, 1, 1, 1, 0.233333333333333, 1, 1, 
1, 1, 1, 1, 1, 0.233333333333333, 0.233333333333333, 1, 1, 0.233333333333333, 
1, 1, 1, 1, 1, 1, 0.233333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 0.233333333333333, 1, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
1, 0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 0.233333333333333, 1, 0.233333333333333, 
1, 1, 1, 1, 1, 1, 0.233333333333333, 0.233333333333333, 1, 1, 
1, 1, 1, 1, 0.233333333333333, 0.233333333333333, 1, 0.233333333333333, 
1, 0.233333333333333, 0.233333333333333, 0.233333333333333, 0.233333333333333, 
0.233333333333333, 0.233333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 
0.233333333333333, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 0.233333333333333, 1, 1, 0, 0.233333333333333, 0.801007556675063, 
0.763224181360202, 0.765743073047859, 0.821158690176322, 0.82367758186398, 
0.829974811083123, 0.690680100755667, 0.702770780856423, 0.763224181360202, 
0.71536523929471, 0.714760705289673, 0.801007556675063, 0.780856423173804, 
0.73551637279597, 0.780856423173804, 0.779596977329975, 0.869017632241814, 
0.698992443324937, 0.717884130982368, 0.861460957178841, 0.841309823677582, 
0.826196473551637, 0.705289672544081, 0.818639798488665, 0.750629722921914, 
0.881612090680101, 0.748110831234257, 0.748110831234257, 0.797229219143577, 
0.86272040302267, 0.712846347607053, 0.712846347607053, 0.748110831234257, 
0.86272040302267, 0.879093198992443, 0.931989924433249, 0.86272040302267, 
0.86272040302267, 0.934508816120907, 0.842569269521411, 0.817380352644836, 
0.842569269521411, 0.797229219143577, 0.931989924433249, 0.931989924433249, 
0.797229219143577, 0.792191435768262, 0.931989924433249, 0.931989924433249, 
0.931989924433249, 0.98992443324937, 0.98992443324937, 0.98992443324937, 
0.98992443324937, 0.881612090680101, 0.919395465994962, 0.98992443324937, 
0.98992443324937, 0.906801007556675, 0.98992443324937, 0.98992443324937, 
0.889168765743073, 0.98992443324937, 0.973551637279597, 0.98992443324937, 
0.738035264483627, 0.98992443324937, 0.973551637279597, 0.98992443324937, 
0.889168765743073, 1, 0.973551637279597, 0.973551637279597, 0.973551637279597, 
1, 0.973551637279597, 1, 0.937027707808564, 0.973551637279597, 
1, 0.826196473551637, 0.826196473551637, 0.826196473551637, 0.843828715365239, 
1, 0.843828715365239, 0.806045340050378, 0.826196473551637, 0.843828715365239, 
0.826196473551637, 0.806045340050378, 0.806045340050378, 0.806045340050378, 
0.843828715365239, 0.826196473551637, 0.843828715365239, 0.843828715365239, 
0.937027707808564, 0.937027707808564, 1, 0.649874055415617, 0.843828715365239, 
0.748110831234257, 0.782115869017632, 0.773299748110831, 0.795969773299748, 
0.773299748110831, 0.843828715365239, 0.842569269521411, 0.773299748110831, 
0.773299748110831, 0.773299748110831, 0.773299748110831, 0.793450881612091, 
0.761964735516373, 0.685138539042821, 0.797229219143577, 0.836272040302267, 
0.797229219143577, 0.804785894206549, 0.86272040302267, 0.86272040302267, 
0.797229219143577, 0.842569269521411, 0.881612090680101, 0.851385390428212, 
0.86272040302267, 0.804785894206549, 0.931989924433249, 0.806045340050378, 
0.826196473551637, 1, 0.705289672544081, 0.619647355163728, 0.973551637279597, 
0.649874055415617, 0.748110831234257, 0.806045340050378, 0.779596977329975, 
0.712846347607053, 0.802267002518892, 0.856423173803526, 0.818639798488665, 
0.842569269521411, 0.842569269521411, 0.748110831234257, 0.797229219143577, 
0.804785894206549, 0.748110831234257, 0.86272040302267, 0.804785894206549, 
0.758186397984887, 0.797229219143577, 0.748110831234257, 0.804785894206549, 
0.804785894206549, 0.86272040302267, 0.842569269521411, 0.86272040302267, 
0.748110831234257, 0.804785894206549, 0.797229219143577, 0.797229219143577, 
0.748110831234257, 0.712846347607053, 0.748110831234257, 0.86272040302267, 
0.797229219143577, 0.712846347607053, 0.780856423173804, 0.931989924433249, 
0.826196473551637, 0.770780856423174, 0.817380352644836, 0.1, 
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 
0.9, 0.1, 0.1, 0.1, 0.5, 0.2, 0.4, 0.1, 1, 0.3, 0.5, 0.1, 0.4, 
0.1, 0.1, 0.1, 0.5, 0.5, 0.8, 0.1, 0.4, 0.3, 0.1, 0.1, 0.1, 0.1, 
0.3, 0.1, 0.1, 0.4, 0.7, 0.1, 0.4, 0.4, 0.1, 0.2, 0.4, 0.3, 0.1, 
0.2, 0.2, 0.1, 0.1, 0.1, 0.4, 0.5, 0.1, 0.2, 0.5, 0.1, 0.1, 0.5, 
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.5, 0.3, 0.1, 0.1, 
0.1, 0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.1, 0.1, 0.1, 
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 
0.1, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.1, 0.1, 0.2, 
0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 
0.1, 0.2, 0.5, 0.5, 0.1, 0.1, 0.3, 0.1, 0.2, 0.2, 0.4, 0.2, 0.3, 
0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.5, 0.4, 0.1, 0.2, 0.2, 0.5, 0.1, 
0.1, 0.1, 0.1, 0.1), .Dim = c(174L, 5L), .Dimnames = list(c("159", 
"165", "167", "170", "171", "172", "173", "183", "184", "185", 
"186", "190", "191", "192", "193", "194", "195", "199", "200", 
"205", "209", "210", "225", "232", "237", "251", "254", "255", 
"256", "257", "258", "259", "260", "262", "265", "269", "270", 
"272", "274", "276", "277", "278", "281", "282", "283", "287", 
"288", "290", "291", "292", "295", "296", "298", "299", "300", 
"301", "302", "303", "305", "308", "309", "310", "311", "312", 
"316", "318", "320", "322", "323", "324", "325", "326", "329", 
"332", "333", "340", "343", "344", "346", "347", "348", "350", 
"351", "355", "358", "359", "361", "362", "364", "365", "366", 
"369", "370", "371", "373", "375", "376", "378", "379", "380", 
"381", "383", "389", "156", "157", "160", "161", "163", "175", 
"178", "179", "180", "182", "189", "198", "204", "221", "222", 
"223", "224", "236", "244", "245", "273", "279", "280", "286", 
"289", "304", "307", "336", "337", "353", "354", "382", "385", 
"386", "391", "394", "395", "396", "397", "400", "174", "176", 
"177", "196", "197", "201", "203", "207", "208", "215", "216", 
"217", "218", "228", "230", "234", "238", "240", "241", "242", 
"243", "246", "248", "250", "252", "253", "268", "285", "356", 
"387", "399"), c("PAYMENT_DELAY", "AMT_TO_PAY", "NUMBER_OF_CREDIT_DAYS", 
"AVG_BASE_PRICE", "DELIVERY_DURATION")))

ok, it is a bug that will be fixed in the next release.

In the meantime, convert nnTrainingDataScaled to a data frame before using train and you can get predictions.

Thanks,

Max

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM