Implement LOOCV for model selection (without using the caret package)
I have this dataset, and I'm trying to find the best model for it using R.
Dataset:
structure(list(V1 = c(1.43359910241166, 0.411971077467806, 0.236361845246534,
-0.289263426819727, -1.23202861459847, -0.738796384986188, 0.200420172968439,
1.55763841132305, 0.306848974278087, -1.06336757529454, 0.208462982445177,
-0.161933544137143, -0.529226737265933, 1.06311471300635, 0.154281146875831,
0.609577869238014, -0.13720552696616, 0.920650581183744, 1.18282854178987,
-0.792945405521446, -0.609722647650392, -0.21688852962299, -3.06426186175807,
0.5498848363865), V2 = c(-0.322161064448354, -0.203202315321523,
-1.37681357972322, -2.09183896169083, -1.73416522569493, -0.167163678879473,
-0.496644140621754, -0.378640254832213, 1.71897857982319, 0.987886990249993,
-0.464176577243306, 0.313599912560739, 0.279305189424942, 0.621879051693468,
-1.35413705469938, -0.904307866112488, 0.563960402008738, 0.942178870082166,
1.05504527675313, 1.72684177309, 0.487583880103759, 0.366982237506534,
0.341207392409481, 0.0878011635613361), V3 = c(-3.06259779143185,
0.113156471002083, -0.596111339640452, -0.0549465535711572, -0.941898864240695,
-0.653015082018507, -0.169956676284042, -0.35411953808696, 0.713862293279259,
1.20019049753438, 0.295042002436139, -0.248609439893179, 1.9312167684667,
0.674670687298312, 0.224140747830105, -0.59349261052001, 0.0558808922143246,
0.749007982254512, 1.04584894162381, 0.280651184742914, -0.313568542992107,
-1.54267082673779, 0.397265080266878, 0.850053716467332), V4 = c(-2.72697312636474,
0.851743869193346, -0.0599187094978506, 0.341978048955579, -0.484015693411596,
-0.131475393689722, -0.021866557862478, -1.8191792655517, 1.74883985589495,
-0.446343374015597, 0.0107633789594956, 0.55528371030783, 0.31132242799237,
0.0710046563366782, 0.701388784100771, 1.56870481640847, -0.841113890934613,
-0.881987858407386, 1.37693978208629, -0.488560120797117, 0.366895195216852,
0.0627972059134885, -0.655416452787133, 0.589188711953821), V5 = c(-1.79836984688233,
0.50295466271361, -1.17227869532777, 0.661412408202374, 0.853890060320874,
0.349725611664228, -0.308069063888987, -0.433246608902138, -0.178767449882736,
1.34125510863996, 0.206474174580616, -0.657831069822233, 0.215632332747088,
0.573672331330443, -0.202823754124207, 0.609758501891791, 0.222044482387977,
2.56037433110525, -1.29345283990688, 0.174550400877521, -0.174265216769768,
0.55419775558349, -0.458225457879011, -2.14861215865916), V6 = c(-0.18026818728965,
-0.480816154309526, -0.50256960223903, -1.31874854057412, -0.896086924318379,
-1.79382217103909, -1.60213450587948, -0.481119812364401, 0.377075792056211,
1.34981730088023, 0.0611706096060544, 0.83874651540465, 0.58899516399665,
1.24066391945654, -1.08080170411743, 0.597620326597847, -1.21365483260366,
0.230893469563153, -0.576677068566099, 1.31703258659203, 0.35136844419016,
0.925208426922233, 1.73348977742475, 0.514617170610343), V7 = c(0.692646184527114,
1.64958468445801, -0.722861261417701, -0.411292490473929, -1.73926867251488,
0.479847732965793, 0.224291785874008, -0.650661070391403, -0.20779505689401,
-0.900990363217965, 0.712570690351891, 0.0291624484927884, 0.613871305452367,
-0.901767959624604, -0.184130922600279, 2.60941994159236, 0.0144701586285878,
1.00941096184201, -1.07148389565784, -0.439790917550134, -0.786567592396622,
0.926243735906836, -1.39392614240757, 0.449016715055174), V8 = c(-0.218730876718155,
0.279536175230915, -0.860839531512879, 1.62382620633742, -0.656202640703168,
-3.05801703213563, 0.243884147081474, 0.926579301241956, 0.58184138659717,
-0.0814078168437784, -0.0668035158044736, 0.00153834639170001,
0.806767895958209, 0.834326360087515, -0.0790896439523125, 0.07028192584928,
-0.619273530317688, 1.07556660504801, -1.13473924521572, 0.668145147063421,
0.758090513962191, 0.456430947715887, -1.73160959029873, 0.179898464937389
), V9 = c(2.56974590352874, -0.263155790779132, 0.646658371629822,
-0.752843366448987, 0.200047856906594, 0.659371008337854, 1.24620285734473,
0.94634794321528, -1.3304334794271, 1.33090401796431, -0.819840444239054,
0.272969704571894, -0.486961950780986, 0.169639870524667, -0.451658048721127,
-1.04537018765646, -1.16107891054576, -1.20995090654021, -0.839823653138378,
0.62253221198192, 0.622634591405887, -0.547608828939565, 0.786557248787584,
-1.16488601898254), V10 = c(-2.26412916115509, 0.67348993363598,
-0.342027192999345, 0.249815496496033, 0.30352488488975, -0.744451635640458,
1.58487417838063, -1.01570448604582, -0.541105970352036, 1.13647671257197,
-0.54886598448313, -0.962789161396563, -0.538065955333129, 0.0781727823942247,
0.0970193660300894, 1.18927210039089, -0.6957686086705, -0.386785336508124,
-0.35257548033064, 2.31937096293864, -0.549132531058022, -0.0974568592721698,
1.43853645612397, -0.0316945106071529), V11 = c(-1.86095070927053,
0.573330283491408, -1.03183858717977, -1.83745190916475, -0.077180684913356,
-0.94533768863225, -0.641638632478328, 0.154349543995556, 1.89664953662371,
1.3494700201932, 1.04343452008192, 1.03948878970461, 0.394740150081754,
1.24869842481551, 0.33270007318232, 0.373677276693529, 0.670774298645023,
-0.0191045174843475, 0.0901593335518681, -0.813757209813031,
-0.527741614949631, -1.55637393322463, -0.0817683516977811, 0.225671587747989
), V12 = c(0.235155165117673, 0.0334071835637513, 0.141983465568844,
0.441692874434554, 0.0707526888389656, 0.332161357520943, 0.0735800395703528,
-0.281305763416249, 0.16538364649173, -1.15487983901285, 1.56899928098857,
-0.567750194144175, 0.541218236160627, 1.48159680904495, -0.568523352759803,
-0.0545712227404042, -2.93340050534491, 0.662421496450859, 1.11729205722267,
-0.581175560009803, 0.792548304722282, 0.955149345977461, -0.821090667653583,
-1.65064484659245), V13 = c(-1.97412125867671, 0.44572205242864,
-0.274712915255066, -1.44692140049933, -1.18035700830368, -0.260286573948736,
-0.95815595797825, -0.242760674716397, 0.477953228907608, 0.992878959448502,
0.48518262700317, -0.882424015844636, 2.03856721097186, 0.782640940939034,
0.00789969362112054, -0.295894328060507, 1.27922468162261, 0.51472928905797,
0.0447383908218823, 0.165638463053774, -0.263332324321804, -1.15204704327981,
-0.258342890933598, 1.95418085394235), V14 = c(-0.181993529177506,
1.39403983793056, -0.152733307069606, -1.52421030170283, -0.924924418962197,
-0.364387222675804, 1.10283509955152, 0.0727783277608945, -1.77522562543095,
1.08664918075833, -1.04803884297856, -0.940631906527986, 1.12617755875177,
1.21705368328955, -0.279102677856877, 0.343713803473868, 1.26542530994074,
-0.774396836280874, 0.417125600747737, 1.49096714826284, 0.284166748008431,
-1.53295609357739, 0.105608954195959, -0.407940490431605), V15 = c(-1.46474265513464,
1.19486941463858, 0.244933071673175, -0.459011700723317, 0.241718140420906,
0.282959623977014, 0.00585677416957126, -2.03400384857495, 0.537918956631718,
-1.04030075327707, -0.557219563096931, -0.252427064540924, 0.547956268292219,
-0.526158422645334, 0.251554548033225, -0.745912076395139, -0.0351666299711204,
1.15204026955591, 0.842246979246097, 1.52268303136091, -1.90156582122334,
-0.142035061237368, 0.385224459566802, 1.94858205925399), V16 = c(0.828548104520814,
0.713189024971904, 0.774573684318552, -0.425568343697551, 0.259608074896051,
-1.22029633555545, -0.344755278537263, 0.973749897026122, -0.474553098183039,
0.0257155566445092, -0.476287023663646, 0.974669054546108, -1.77164686907544,
1.56028342699847, 1.24959541751606, -0.574201649578301, 1.2099741843225,
-0.0750690376790856, -0.0110241372862062, -0.984530244128971,
-2.52086075001167, 0.0287667805602271, 0.731343831738835, -0.451224270663529
), V17 = c(-0.681074029216176, -0.0390433509889875, 0.0328512523391066,
1.12428796011696, 0.176765286103444, -0.222850967042728, 0.988520019729737,
2.09179105565111, 0.116819106946508, 0.51447781508645, 1.87648378755979,
-1.08036997332246, -0.418517756914466, 0.291253915397003, -0.355756145391065,
0.874359244531183, -2.35192438381252, -0.200559130397419, -1.29305021151605,
-0.216777649470054, -1.43207151780606, -0.392317470556723, 0.447601162558867,
0.149101980414553), V18 = c(-1.96475300593026, 0.422711683040055,
-1.12996029903421, -2.33587910613298, 0.179352498545959, -0.600058127770143,
-1.35077156778998, -0.727365308346169, 1.43052873254504, 1.07048786910024,
1.15649152054786, 0.702163956193049, 0.599458156020645, 0.489172517239038,
0.957116387643539, 0.335186798948586, -0.598777825023964, 0.10012893280699,
0.0822063408722808, 0.393896776121708, 0.968441995451939, -0.625513747288306,
-0.437871585012806, 0.883606407251895), V19 = c(0.203243289070699,
0.206783154660488, 0.0730205054389099, 0.151752499129077, 0.339065300597841,
0.198750153846351, 0.246574181097875, 0.219716854159337, 0.112571755773366,
0.108437458425644, 0.159923853880819, 0.198217376539615, 1.27794667790059,
0.0628191359027579, -0.023668700184257, 0.0103470645871769, -4.55192891533295,
0.0932248108210876, 0.0372915017676821, 0.103290843005291, 0.1485089149749,
0.167015138770557, 0.258108289841612, 0.198988855325523), V20 = c(-0.6885610185506,
0.215106818871655, -1.26229703607397, -1.15415874394993, -0.770942786330788,
-1.07811513531511, -1.34581518035362, 0.296281823344214, -0.525449013409778,
1.52659228597052, 1.66011376586839, 0.204981756466606, 2.25710524990656,
0.850893107617607, 0.181598239123184, 0.0790398588000734, -0.0665218787774753,
0.411298611581292, 0.0839458342094344, -0.122405563089466, -1.6897393933796,
1.24061257187769, -0.157685318761091, -0.145878855645788), outcome_var = c(-3,
4, 1, -1, -1, -3, -1, -3, 3, 2, -2, -3, 1, 0, 0, 0, 3, 0, 2,
2, 1, -3, 1, 0)), class = "data.frame", row.names = c(NA, -24L
))
Code:
library(caret)  # provides trainControl() and train(); method "leapSeq" also needs the leaps package
train.control <- trainControl(method = "LOOCV")
step.model <- train(outcome_var ~., data = total,
method = "leapSeq",
tuneGrid = data.frame(nvmax = 1:5),
trControl = train.control
)
step.model$results
summary(step.model$finalModel)
Result:
20 Variables (and intercept)
Forced in Forced out
V1 FALSE FALSE
V2 FALSE FALSE
V3 FALSE FALSE
V4 FALSE FALSE
V5 FALSE FALSE
V6 FALSE FALSE
V7 FALSE FALSE
V8 FALSE FALSE
V9 FALSE FALSE
V10 FALSE FALSE
V11 FALSE FALSE
V12 FALSE FALSE
V13 FALSE FALSE
V14 FALSE FALSE
V15 FALSE FALSE
V16 FALSE FALSE
V17 FALSE FALSE
V18 FALSE FALSE
V19 FALSE FALSE
V20 FALSE FALSE
1 subsets of each size up to 3
Selection Algorithm: 'sequential replacement'
V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12 V13 V14 V15 V16 V17 V18 V19 V20
1 ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " "
2 ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " "*" " " " " " "
3 ( 1 ) " " " " " " " " " " " " " " " " " " "*" " " " " "*" " " " " " " "*" " " " " " "
This gives me the output I'm looking for, but now I'm trying to write my own LOOCV function instead of using the caret package, and I'm not getting the same results:
loocv = function(fit) {
n = length(fit$residuals)
yvar = fit$model[, 1]
index = 1:n
e = rep(NA, n)
for (i in index) {
refit = update(fit, subset = index != i)
pred = predict(refit, dplyr::slice(fit$model, i))
e[i] = yvar[i] - pred
}
return(mean(e^2))
}
How can I use LOOCV without the caret package and find the best-fitting model?
For cross-validation such as LOOCV, the model should be built from scratch for each test fold, and that includes the variable selection step, not just the final fit. By trial and error, I believe caret uses leaps::regsubsets for the stepwise model selection.
library(leaps)
nvmax = 3  # maximum number of variables
pred = rep(NA, nrow(total))
for (i in seq(nrow(total))) {  # LOOCV: hold out row i
  # redo the variable selection on the training rows only
  tem = regsubsets(x = total[-i, 1:20],
                   y = total[-i, 21],
                   nvmax = nvmax,
                   method = "seqrep")
  # refit lm on the variables chosen for the nvmax-variable model
  fit = lm(outcome_var ~ .,
           data = total[-i,
                        c(which(summary(tem)$which[nvmax, -1]),
                          21)])
  # predict the held-out row
  pred[i] = predict(fit, newdata = total[i, ])
}
# RMSE() and MAE() are helper functions from the caret package
RMSE(pred, total[, 'outcome_var'])
#1.945036
MAE(pred, total[, 'outcome_var'])
#1.442353
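Since the goal is to avoid caret entirely, the two metric calls above can be swapped for base R. The following is a minimal sketch. The function names rmse, mae and rsq are illustrative rather than part of any package, and the squared-correlation definition of R-squared is my understanding of what caret's Rsquared column reports by default, so treat that part as an assumption.

```r
# Base-R stand-ins for the caret metric helpers (names are illustrative).
rmse <- function(pred, obs) sqrt(mean((pred - obs)^2))
mae  <- function(pred, obs) mean(abs(pred - obs))
# Assumption: R-squared taken as the squared correlation of predictions and observations.
rsq  <- function(pred, obs) cor(pred, obs)^2

# Tiny made-up vectors, only to show the calls.
pred_demo <- c(1, 2, 3, 4)
obs_demo  <- c(1, 2, 3, 8)
rmse(pred_demo, obs_demo)  # 2
mae(pred_demo, obs_demo)   # 1
```

With the loop above, rmse(pred, total[, 'outcome_var']) and mae(pred, total[, 'outcome_var']) would take the place of the RMSE() and MAE() calls.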
Results from caret:
step.model$results
# nvmax RMSE Rsquared MAE
# 3 1.945036 0.238655497 1.442353