R - cross-validation error: "dims [product] do not match the length of object"

I have been working through some examples of statistical learning models with the ISLR package. The code is available at https://rpubs.com/davoodastaraky/subset for anyone to see, and I have also pasted it below for convenience.

library(ISLR)
library(leaps)
data(Hitters)
Hitters
regfit.full = regsubsets(Salary ~ ., data = Hitters, nvmax = 19)
reg.summary = summary(regfit.full)
#plot rss
library(ggvis)
rsq <- as.data.frame(reg.summary$rsq)
names(rsq) <- "R2"
rsq %>% 
  ggvis(x=~ c(1:nrow(rsq)), y=~R2 ) %>%
  layer_points(fill = ~ R2 ) %>%
  add_axis("y", title = "R2") %>% 
  add_axis("x", title = "Number of variables")

par(mfrow=c(2,2))
plot(reg.summary$rss ,xlab="Number of Variables ",ylab="RSS",type="l")
plot(reg.summary$adjr2 ,xlab="Number of Variables ", ylab="Adjusted RSq",type="l")
# which.max(reg.summary$adjr2)
points(11,reg.summary$adjr2[11], col="red",cex=2,pch=20)
plot(reg.summary$cp ,xlab="Number of Variables ",ylab="Cp", type='l')
# which.min(reg.summary$cp )
points(10,reg.summary$cp [10],col="red",cex=2,pch=20)
plot(reg.summary$bic ,xlab="Number of Variables ",ylab="BIC",type='l')
# which.min(reg.summary$bic )
points(6,reg.summary$bic [6],col="red",cex=2,pch=20)
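The commented-out which.max()/which.min() calls point to a sturdier way of placing the red points: compute the best model size instead of hard-coding 11, 10 and 6. A minimal sketch with made-up criterion values (in the real script these vectors would be reg.summary$adjr2, reg.summary$cp and reg.summary$bic):

```r
# Made-up criterion values for models with 1..6 variables (illustration only)
adjr2 <- c(0.32, 0.42, 0.45, 0.47, 0.46, 0.45)
cp    <- c(104, 50, 38, 27, 21, 23)
bic   <- c(-90, -128, -135, -141, -139, -134)

# Best model size under each criterion:
# largest adjusted R^2, smallest Cp, smallest BIC
best <- c(adjr2 = which.max(adjr2), cp = which.min(cp), bic = which.min(bic))
best
```

With the real summary object, something like `points(best["adjr2"], reg.summary$adjr2[best["adjr2"]], col = "red", cex = 2, pch = 20)` then marks the point without a magic number.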

plot(regfit.full,scale="bic")

set.seed (1)
train = sample(c(TRUE,FALSE), nrow(Hitters),rep=TRUE)
test =(! train )

predict.regsubsets =function (object ,newdata ,id ,...){
  form=as.formula(object$call [[2]])
  mat=model.matrix(form,newdata)
  coefi=coef(object ,id=id)
  xvars=names(coefi)
  mat[,xvars]%*%coefi
}
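To see what predict.regsubsets() does step by step, here is a minimal sketch of the same recipe (formula from object$call[[2]], then model.matrix(), then the selected columns times the coefficients) applied to an lm() fit on the built-in mtcars data. lm() is only an assumed stand-in so the snippet runs with base R, without leaps.

```r
# lm() stands in for regsubsets() so this runs with base R only
fit <- lm(mpg ~ wt + hp, data = mtcars)

form   <- as.formula(fit$call[[2]])       # same trick as object$call[[2]]
mat    <- model.matrix(form, mtcars)      # design matrix, intercept column included
coefi  <- coef(fit)                       # named coefficient vector
manual <- mat[, names(coefi)] %*% coefi   # n x 1 matrix; row names come from mtcars

# The hand-rolled prediction agrees with the built-in predict() method
all.equal(as.vector(manual), unname(predict(fit, mtcars)))
```

Because model.matrix() builds its input through model.frame(), any row with an NA in the response or a predictor is dropped from `mat` under the default `na.action`, so the prediction can end up with fewer rows than the data passed in.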

regfit.best=regsubsets(Salary~.,data=Hitters ,nvmax=19)
coef(regfit.best ,10)

k = 10
set.seed(1)
folds = sample(1:k,nrow(Hitters),replace=TRUE)
table(folds)

The code runs smoothly until I reach this part:

cv.errors = matrix(NA, k, 19, dimnames = list(NULL, paste(1:19)))  # k x 19 matrix of test MSEs
for (j in 1:k) {
  best.fit = regsubsets(Salary ~ ., data = Hitters[folds != j, ], nvmax = 19)
  for (i in 1:19) {
    pred = predict.regsubsets(best.fit, Hitters[folds == j, ], id = i)
    cv.errors[j, i] = mean((Hitters$Salary[folds == j] - pred)^2)
  }
}

At this point I get the following error:

Error in mean((Hitters$Salary[folds == j] - pred)^2) : 
  dims [product 18] do not match the length of object [22]
In addition: Warning message:
In Hitters$Salary[folds == j] - pred :
  longer object length is not a multiple of shorter object length

My question is: why am I getting this error, and how do I fix it? The code is taken directly from the site and I haven't altered it in any way. Clearly I am missing something about object length. Thanks.
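A minimal sketch that reproduces the same error without ISLR or leaps. The numbers are made up and only mimic the shapes involved: a length-22 response vector (like Hitters$Salary[folds == j]) and an 18 x 1 prediction matrix (like the output of predict.regsubsets()).

```r
# Made-up values; only the shapes matter here
actual <- c(NA, 475, 480, 500, rep(100, 18))  # length 22, like Hitters$Salary[folds == j]
pred   <- matrix(300, nrow = 18, ncol = 1)     # 18 x 1, like predict.regsubsets() output

# Subtracting a length-22 vector and an 18-row matrix fails: the result would
# have 22 elements, which cannot carry the matrix's 18 x 1 dimensions
msg <- tryCatch(
  suppressWarnings(actual - pred),
  error = function(e) conditionMessage(e)
)
print(msg)
```

The message it captures is the same "dims ... do not match the length of object" error from the question, which shows the problem is purely a length mismatch between the two operands.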

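The lengths differ because model.matrix() inside predict.regsubsets() silently drops the rows with a missing Salary, so pred has only 18 rows while Hitters$Salary[folds == j] still has 22 values, NAs included. If you want to "fix" this, you need to pull the row names out of the dimnames attribute of pred and then select the matching Salary values from Hitters based on its rownames().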

> str(Hitters$Salary)
 num [1:322] NA 475 480 500 91.5 750 70 100 75 1100 ...
> str(pred)
 num [1:18, 1] 988 359 370 808 383 ...
 - attr(*, "dimnames")=List of 2
  ..$ : chr [1:18] "-Andre Thornton" "-Bob Dernier" "-Chris Brown" "-Chet Lemon" ...
  ..$ : NULL
> names(Hitters)
 [1] "AtBat"     "Hits"      "HmRun"     "Runs"      "RBI"       "Walks"     "Years"     "CAtBat"   
 [9] "CHits"     "CHmRun"    "CRuns"     "CRBI"      "CWalks"    "League"    "Division"  "PutOuts"  
[17] "Assists"   "Errors"    "Salary"    "NewLeague"
> rownames(Hitters)
  [1] "-Andy Allanson"     "-Alan Ashby"        "-Alvin Davis"       "-Andre Dawson"     
  [5] "-Andres Galarraga"  "-Alfredo Griffin"   "-Al Newman"         "-Argenis Salazar"  
  [9] "-Andres Thomas"     "-Andre Thornton"    "-Alan Trammell"     "-Alex Trevino"     
 [13] "-Andy VanSlyke"     "-Alan Wiggins"      "-Bill Almon"        "-Billy Beane"
# (remaining row names omitted; there are 322 in total)
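Along the lines of this answer, a minimal sketch on a toy data frame (hypothetical values, base R only, with `y` playing the role of Salary) showing that model.matrix() drops the NA rows but keeps the row names, and that indexing the response by rownames(pred) lines the two up again:

```r
# Toy stand-in for Hitters: y plays the role of Salary (hypothetical values)
df <- data.frame(
  y = c(NA, 10, 12, NA, 15, 18),
  x = 1:6,
  row.names = paste0("player", 1:6)
)

# model.matrix() goes through model.frame(), whose default na.action drops
# the rows where y is NA, so only 4 of the 6 rows survive
mm <- model.matrix(y ~ x, df)

coefi <- c("(Intercept)" = 1, x = 3)   # hypothetical coefficients
pred  <- mm[, names(coefi)] %*% coefi  # 4 x 1 matrix; row names are kept

# Select the observed responses that belong to the predicted rows
actual <- df[rownames(pred), "y"]
mse <- mean((actual - pred)^2)
mse
```

In the cross-validation loop the equivalent would be `Hitters[rownames(pred), "Salary"]` in place of `Hitters$Salary[folds == j]`.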

There are missing values in the Salary column of the Hitters dataset. Just drop them (for example with na.omit()) before fitting, and the code works as expected.
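A minimal sketch of this fix on simulated data (hypothetical, base R only): rows with a missing response are dropped *before* the folds are drawn, so predictions and observed values always have the same length. lm() stands in for regsubsets(), and balanced folds via `rep()` are an assumed design choice (the original uses `sample(1:k, ..., replace = TRUE)`) that guarantees no fold is empty.

```r
# Simulated data with a few missing responses (stands in for Hitters)
set.seed(1)
n <- 60
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- 2 + 3 * toy$x1 - toy$x2 + rnorm(n)
toy$y[c(3, 17, 42)] <- NA                 # mimic the NA values in Salary

# The fix: drop incomplete rows before assigning folds
toy <- na.omit(toy)

k <- 5
folds <- sample(rep(1:k, length.out = nrow(toy)))  # balanced fold labels
cv.errors <- numeric(k)
for (j in 1:k) {
  fit  <- lm(y ~ ., data = toy[folds != j, ])      # lm() stands in for regsubsets()
  pred <- predict(fit, toy[folds == j, ])
  # pred and toy$y[folds == j] now have the same length
  cv.errors[j] <- mean((toy$y[folds == j] - pred)^2)
}
mean(cv.errors)
```

For the original script, the same idea is running `Hitters = na.omit(Hitters)` right after `data(Hitters)`, before regsubsets() is called and before the folds are sampled.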
