简体   繁体   English

为模型选择实现 LOOCV(不使用 caret 包)

[英]Implement LOOCV for model selection (Without using the caret package)

I have this dataset, and I'm trying to find the best model for it using R.我有这个数据集,我正在尝试使用 R 为它找到最好的模型。

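Dataset (the code below refers to this data frame as `total`, i.e. `total <- structure(...)`):数据集(下面的代码将这个数据框称为 `total`,即 `total <- structure(...)`):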

structure(list(V1 = c(1.43359910241166, 0.411971077467806, 0.236361845246534, 
-0.289263426819727, -1.23202861459847, -0.738796384986188, 0.200420172968439, 
1.55763841132305, 0.306848974278087, -1.06336757529454, 0.208462982445177, 
-0.161933544137143, -0.529226737265933, 1.06311471300635, 0.154281146875831, 
0.609577869238014, -0.13720552696616, 0.920650581183744, 1.18282854178987, 
-0.792945405521446, -0.609722647650392, -0.21688852962299, -3.06426186175807, 
0.5498848363865), V2 = c(-0.322161064448354, -0.203202315321523, 
-1.37681357972322, -2.09183896169083, -1.73416522569493, -0.167163678879473, 
-0.496644140621754, -0.378640254832213, 1.71897857982319, 0.987886990249993, 
-0.464176577243306, 0.313599912560739, 0.279305189424942, 0.621879051693468, 
-1.35413705469938, -0.904307866112488, 0.563960402008738, 0.942178870082166, 
1.05504527675313, 1.72684177309, 0.487583880103759, 0.366982237506534, 
0.341207392409481, 0.0878011635613361), V3 = c(-3.06259779143185, 
0.113156471002083, -0.596111339640452, -0.0549465535711572, -0.941898864240695, 
-0.653015082018507, -0.169956676284042, -0.35411953808696, 0.713862293279259, 
1.20019049753438, 0.295042002436139, -0.248609439893179, 1.9312167684667, 
0.674670687298312, 0.224140747830105, -0.59349261052001, 0.0558808922143246, 
0.749007982254512, 1.04584894162381, 0.280651184742914, -0.313568542992107, 
-1.54267082673779, 0.397265080266878, 0.850053716467332), V4 = c(-2.72697312636474, 
0.851743869193346, -0.0599187094978506, 0.341978048955579, -0.484015693411596, 
-0.131475393689722, -0.021866557862478, -1.8191792655517, 1.74883985589495, 
-0.446343374015597, 0.0107633789594956, 0.55528371030783, 0.31132242799237, 
0.0710046563366782, 0.701388784100771, 1.56870481640847, -0.841113890934613, 
-0.881987858407386, 1.37693978208629, -0.488560120797117, 0.366895195216852, 
0.0627972059134885, -0.655416452787133, 0.589188711953821), V5 = c(-1.79836984688233, 
0.50295466271361, -1.17227869532777, 0.661412408202374, 0.853890060320874, 
0.349725611664228, -0.308069063888987, -0.433246608902138, -0.178767449882736, 
1.34125510863996, 0.206474174580616, -0.657831069822233, 0.215632332747088, 
0.573672331330443, -0.202823754124207, 0.609758501891791, 0.222044482387977, 
2.56037433110525, -1.29345283990688, 0.174550400877521, -0.174265216769768, 
0.55419775558349, -0.458225457879011, -2.14861215865916), V6 = c(-0.18026818728965, 
-0.480816154309526, -0.50256960223903, -1.31874854057412, -0.896086924318379, 
-1.79382217103909, -1.60213450587948, -0.481119812364401, 0.377075792056211, 
1.34981730088023, 0.0611706096060544, 0.83874651540465, 0.58899516399665, 
1.24066391945654, -1.08080170411743, 0.597620326597847, -1.21365483260366, 
0.230893469563153, -0.576677068566099, 1.31703258659203, 0.35136844419016, 
0.925208426922233, 1.73348977742475, 0.514617170610343), V7 = c(0.692646184527114, 
1.64958468445801, -0.722861261417701, -0.411292490473929, -1.73926867251488, 
0.479847732965793, 0.224291785874008, -0.650661070391403, -0.20779505689401, 
-0.900990363217965, 0.712570690351891, 0.0291624484927884, 0.613871305452367, 
-0.901767959624604, -0.184130922600279, 2.60941994159236, 0.0144701586285878, 
1.00941096184201, -1.07148389565784, -0.439790917550134, -0.786567592396622, 
0.926243735906836, -1.39392614240757, 0.449016715055174), V8 = c(-0.218730876718155, 
0.279536175230915, -0.860839531512879, 1.62382620633742, -0.656202640703168, 
-3.05801703213563, 0.243884147081474, 0.926579301241956, 0.58184138659717, 
-0.0814078168437784, -0.0668035158044736, 0.00153834639170001, 
0.806767895958209, 0.834326360087515, -0.0790896439523125, 0.07028192584928, 
-0.619273530317688, 1.07556660504801, -1.13473924521572, 0.668145147063421, 
0.758090513962191, 0.456430947715887, -1.73160959029873, 0.179898464937389
), V9 = c(2.56974590352874, -0.263155790779132, 0.646658371629822, 
-0.752843366448987, 0.200047856906594, 0.659371008337854, 1.24620285734473, 
0.94634794321528, -1.3304334794271, 1.33090401796431, -0.819840444239054, 
0.272969704571894, -0.486961950780986, 0.169639870524667, -0.451658048721127, 
-1.04537018765646, -1.16107891054576, -1.20995090654021, -0.839823653138378, 
0.62253221198192, 0.622634591405887, -0.547608828939565, 0.786557248787584, 
-1.16488601898254), V10 = c(-2.26412916115509, 0.67348993363598, 
-0.342027192999345, 0.249815496496033, 0.30352488488975, -0.744451635640458, 
1.58487417838063, -1.01570448604582, -0.541105970352036, 1.13647671257197, 
-0.54886598448313, -0.962789161396563, -0.538065955333129, 0.0781727823942247, 
0.0970193660300894, 1.18927210039089, -0.6957686086705, -0.386785336508124, 
-0.35257548033064, 2.31937096293864, -0.549132531058022, -0.0974568592721698, 
1.43853645612397, -0.0316945106071529), V11 = c(-1.86095070927053, 
0.573330283491408, -1.03183858717977, -1.83745190916475, -0.077180684913356, 
-0.94533768863225, -0.641638632478328, 0.154349543995556, 1.89664953662371, 
1.3494700201932, 1.04343452008192, 1.03948878970461, 0.394740150081754, 
1.24869842481551, 0.33270007318232, 0.373677276693529, 0.670774298645023, 
-0.0191045174843475, 0.0901593335518681, -0.813757209813031, 
-0.527741614949631, -1.55637393322463, -0.0817683516977811, 0.225671587747989
), V12 = c(0.235155165117673, 0.0334071835637513, 0.141983465568844, 
0.441692874434554, 0.0707526888389656, 0.332161357520943, 0.0735800395703528, 
-0.281305763416249, 0.16538364649173, -1.15487983901285, 1.56899928098857, 
-0.567750194144175, 0.541218236160627, 1.48159680904495, -0.568523352759803, 
-0.0545712227404042, -2.93340050534491, 0.662421496450859, 1.11729205722267, 
-0.581175560009803, 0.792548304722282, 0.955149345977461, -0.821090667653583, 
-1.65064484659245), V13 = c(-1.97412125867671, 0.44572205242864, 
-0.274712915255066, -1.44692140049933, -1.18035700830368, -0.260286573948736, 
-0.95815595797825, -0.242760674716397, 0.477953228907608, 0.992878959448502, 
0.48518262700317, -0.882424015844636, 2.03856721097186, 0.782640940939034, 
0.00789969362112054, -0.295894328060507, 1.27922468162261, 0.51472928905797, 
0.0447383908218823, 0.165638463053774, -0.263332324321804, -1.15204704327981, 
-0.258342890933598, 1.95418085394235), V14 = c(-0.181993529177506, 
1.39403983793056, -0.152733307069606, -1.52421030170283, -0.924924418962197, 
-0.364387222675804, 1.10283509955152, 0.0727783277608945, -1.77522562543095, 
1.08664918075833, -1.04803884297856, -0.940631906527986, 1.12617755875177, 
1.21705368328955, -0.279102677856877, 0.343713803473868, 1.26542530994074, 
-0.774396836280874, 0.417125600747737, 1.49096714826284, 0.284166748008431, 
-1.53295609357739, 0.105608954195959, -0.407940490431605), V15 = c(-1.46474265513464, 
1.19486941463858, 0.244933071673175, -0.459011700723317, 0.241718140420906, 
0.282959623977014, 0.00585677416957126, -2.03400384857495, 0.537918956631718, 
-1.04030075327707, -0.557219563096931, -0.252427064540924, 0.547956268292219, 
-0.526158422645334, 0.251554548033225, -0.745912076395139, -0.0351666299711204, 
1.15204026955591, 0.842246979246097, 1.52268303136091, -1.90156582122334, 
-0.142035061237368, 0.385224459566802, 1.94858205925399), V16 = c(0.828548104520814, 
0.713189024971904, 0.774573684318552, -0.425568343697551, 0.259608074896051, 
-1.22029633555545, -0.344755278537263, 0.973749897026122, -0.474553098183039, 
0.0257155566445092, -0.476287023663646, 0.974669054546108, -1.77164686907544, 
1.56028342699847, 1.24959541751606, -0.574201649578301, 1.2099741843225, 
-0.0750690376790856, -0.0110241372862062, -0.984530244128971, 
-2.52086075001167, 0.0287667805602271, 0.731343831738835, -0.451224270663529
), V17 = c(-0.681074029216176, -0.0390433509889875, 0.0328512523391066, 
1.12428796011696, 0.176765286103444, -0.222850967042728, 0.988520019729737, 
2.09179105565111, 0.116819106946508, 0.51447781508645, 1.87648378755979, 
-1.08036997332246, -0.418517756914466, 0.291253915397003, -0.355756145391065, 
0.874359244531183, -2.35192438381252, -0.200559130397419, -1.29305021151605, 
-0.216777649470054, -1.43207151780606, -0.392317470556723, 0.447601162558867, 
0.149101980414553), V18 = c(-1.96475300593026, 0.422711683040055, 
-1.12996029903421, -2.33587910613298, 0.179352498545959, -0.600058127770143, 
-1.35077156778998, -0.727365308346169, 1.43052873254504, 1.07048786910024, 
1.15649152054786, 0.702163956193049, 0.599458156020645, 0.489172517239038, 
0.957116387643539, 0.335186798948586, -0.598777825023964, 0.10012893280699, 
0.0822063408722808, 0.393896776121708, 0.968441995451939, -0.625513747288306, 
-0.437871585012806, 0.883606407251895), V19 = c(0.203243289070699, 
0.206783154660488, 0.0730205054389099, 0.151752499129077, 0.339065300597841, 
0.198750153846351, 0.246574181097875, 0.219716854159337, 0.112571755773366, 
0.108437458425644, 0.159923853880819, 0.198217376539615, 1.27794667790059, 
0.0628191359027579, -0.023668700184257, 0.0103470645871769, -4.55192891533295, 
0.0932248108210876, 0.0372915017676821, 0.103290843005291, 0.1485089149749, 
0.167015138770557, 0.258108289841612, 0.198988855325523), V20 = c(-0.6885610185506, 
0.215106818871655, -1.26229703607397, -1.15415874394993, -0.770942786330788, 
-1.07811513531511, -1.34581518035362, 0.296281823344214, -0.525449013409778, 
1.52659228597052, 1.66011376586839, 0.204981756466606, 2.25710524990656, 
0.850893107617607, 0.181598239123184, 0.0790398588000734, -0.0665218787774753, 
0.411298611581292, 0.0839458342094344, -0.122405563089466, -1.6897393933796, 
1.24061257187769, -0.157685318761091, -0.145878855645788), outcome_var = c(-3, 
4, 1, -1, -1, -3, -1, -3, 3, 2, -2, -3, 1, 0, 0, 0, 3, 0, 2, 
2, 1, -3, 1, 0)), class = "data.frame", row.names = c(NA, -24L
)) 

Code:代码:

# Reference result: LOOCV with caret's "leapSeq" method (sequential-replacement
# subset selection from the leaps package), tuning the maximum number of
# predictors (nvmax) over 1..5.
# Assumes `total` holds the data frame shown above (V1..V20 plus outcome_var).
library(caret)

train.control <- trainControl(method = "LOOCV")

step.model <- train(outcome_var ~ ., data = total,
                    method = "leapSeq",
                    tuneGrid = data.frame(nvmax = 1:5),
                    trControl = train.control
)

# Cross-validated performance for each candidate nvmax.
step.model$results

# Predictors chosen by caret's final model (refitted at the selected nvmax).
summary(step.model$finalModel)

Result:结果:

20 Variables  (and intercept)
Forced in Forced out
V1      FALSE      FALSE
V2      FALSE      FALSE
V3      FALSE      FALSE
V4      FALSE      FALSE
V5      FALSE      FALSE
V6      FALSE      FALSE
V7      FALSE      FALSE
V8      FALSE      FALSE
V9      FALSE      FALSE
V10     FALSE      FALSE
V11     FALSE      FALSE
V12     FALSE      FALSE
V13     FALSE      FALSE
V14     FALSE      FALSE
V15     FALSE      FALSE
V16     FALSE      FALSE
V17     FALSE      FALSE
V18     FALSE      FALSE
V19     FALSE      FALSE
V20     FALSE      FALSE
1 subsets of each size up to 3
Selection Algorithm: 'sequential replacement'
         V1  V2  V3  V4  V5  V6  V7  V8  V9  V10 V11 V12 V13 V14 V15 V16 V17 V18 V19 V20
1  ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " "
2  ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " "*" " " " " " "
3  ( 1 ) " " " " " " " " " " " " " " " " " " "*" " " " " "*" " " " " " " "*" " " " " " "

This gives me the output I'm looking for, but now I'm trying to write my own LOOCV function instead of using the caret package for it.这给了我想要的输出,但现在我正在尝试编写自己的 LOOCV 函数,而不是使用 caret 包。 However, I'm not getting the same results:但是我没有得到相同的结果:

# Leave-one-out cross-validated mean squared error for a fitted model.
#
# Each observation in the model frame is held out in turn. The model is
# refitted with update() on the remaining rows and then used to predict the
# held-out row.
#
# This scores a model whose formula is FIXED. If its predictors were chosen by
# looking at all of the data (e.g. by stepwise selection), that choice is not
# repeated inside each fold. So this cannot reproduce caret's LOOCV over a
# selection method such as "leapSeq".
#
# Args:
#   fit: a fitted model with `$residuals` and `$model` components that
#     supports update(fit, subset = ...) and predict(fit, newdata = ...),
#     e.g. an lm() fit. The first column of `fit$model` is taken as the
#     response.
# Returns:
#   Mean of the squared leave-one-out prediction errors (NaN when the model
#   frame has no rows).
loocv <- function(fit) {
  model_frame <- fit$model
  n <- length(fit$residuals)
  yvar <- model_frame[, 1]
  # seq_len() yields an empty sequence for n == 0, unlike 1:n.
  index <- seq_len(n)
  e <- numeric(n)
  for (i in index) {
    # NOTE(review): `subset` is evaluated against the ORIGINAL data, while
    # `index` only counts rows kept in the model frame. These line up only
    # when the original fit dropped no rows (e.g. no NAs removed) -- confirm
    # before using this on incomplete data.
    refit <- update(fit, subset = index != i)
    # Base `[` with drop = FALSE keeps a one-row data frame, so dplyr is not
    # needed to extract the held-out observation.
    pred <- predict(refit, newdata = model_frame[i, , drop = FALSE])
    e[i] <- yvar[i] - pred
  }
  mean(e^2)
}

How can I implement LOOCV without the caret package and find the best-fitting model?如何在不使用 caret 包的情况下实现 LOOCV 并找到拟合最好的模型?

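For cross-validation such as LOOCV, the model must be built from scratch for each test fold, including the variable-selection step. The `loocv()` function above only refits a model whose predictors were already chosen on all of the data, so it cannot reproduce caret's numbers.对于像 LOOCV 这样的交叉验证,模型应该为每个测试折叠从头开始构建(包括变量选择这一步);上面的 `loocv()` 函数只是重新拟合一个已经在全部数据上选好变量的模型,因此无法得到与 caret 相同的结果。 By trial and error, I believe caret uses leaps::regsubsets for stepwise model selection.通过反复试验,我相信 caret 使用 leaps::regsubsets 进行逐步模型选择。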

library(leaps)

# LOOCV for sequential-replacement subset selection, without caret.
#
# For honest cross-validation the whole modelling procedure, including the
# choice of predictors, is rerun on every training fold. Selecting variables
# once on all rows would let the held-out row influence the model.
# Assumes `total` is the data frame shown above, with response `outcome_var`
# and every other column used as a candidate predictor.

# Leave-one-out predictions for sequential-replacement subset regression.
#
# Args:
#   data: data frame holding the response and the candidate predictors.
#   response: name (string) of the response column; all other columns are
#     candidate predictors.
#   sizes: numbers of predictors to evaluate (the equivalent of caret's
#     `nvmax` grid).
# Returns:
#   An nrow(data) x length(sizes) matrix. Column j holds, for every row, the
#   prediction made by the sizes[j]-predictor model fitted without that row.
loocv_seqrep <- function(data, response, sizes = 1:5) {
  stopifnot(
    is.character(response), length(response) == 1L,
    response %in% names(data),
    length(sizes) > 0L, all(sizes >= 1)
  )
  predictors <- setdiff(names(data), response)
  stopifnot(max(sizes) <= length(predictors))
  n <- nrow(data)
  pred <- matrix(NA_real_, nrow = n, ncol = length(sizes),
                 dimnames = list(NULL, paste0("nvmax_", sizes)))
  for (i in seq_len(n)) {
    train_rows <- data[-i, , drop = FALSE]
    # A single regsubsets() fit up to the largest size also reports the
    # chosen model for every smaller size (one row of summary()$which each).
    sel <- regsubsets(x = train_rows[, predictors, drop = FALSE],
                      y = train_rows[[response]],
                      nvmax = max(sizes),
                      method = "seqrep")
    chosen <- summary(sel)$which
    for (j in seq_along(sizes)) {
      # Drop the intercept column, keep the predictors flagged for this size.
      vars <- colnames(chosen)[-1][chosen[sizes[j], -1]]
      fit <- lm(reformulate(vars, response = response), data = train_rows)
      pred[i, j] <- predict(fit, newdata = data[i, , drop = FALSE])
    }
  }
  pred
}

nvmax_grid <- 1:5
loo_pred <- loocv_seqrep(total, "outcome_var", sizes = nvmax_grid)
observed <- total[["outcome_var"]]

# Error summaries in base R (caret's RMSE()/MAE() helpers are not needed).
cv_results <- data.frame(
  nvmax = nvmax_grid,
  RMSE = sqrt(colMeans((loo_pred - observed)^2)),
  MAE = colMeans(abs(loo_pred - observed)),
  row.names = NULL
)
cv_results

# Pick the size with the lowest LOOCV RMSE, then rerun the selection on all
# rows to obtain the final set of predictors.
best <- which.min(cv_results$RMSE)
nvmax <- cv_results$nvmax[best]
pred <- loo_pred[, best]
final_model <- regsubsets(
  x = total[, setdiff(names(total), "outcome_var"), drop = FALSE],
  y = total[["outcome_var"]],
  nvmax = nvmax,
  method = "seqrep"
)
summary(final_model)

# The original answer reported RMSE 1.945036 and MAE 1.442353 for nvmax = 3,
# the same values caret shows for nvmax = 3.

Results from caret:caret 的结果:

# Show caret's cross-validated summary table for comparison with the
# hand-written LOOCV above.
print(step.model[["results"]])
# nvmax     RMSE    Rsquared      MAE
#     3 1.945036 0.238655497 1.442353

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM