简体   繁体   English

map2 在两个列表上并在新列表中创建一个矩阵

[英]map2 over two lists and create a matrix in a new list

I have two lists, each with 2 observations.我有两个列表,每个列表都有 2 个观察值。

I would like to use map2 to create a new list and set up a matrix.我想使用map2创建一个新列表并设置一个矩阵。 What I currently have is the following:我目前拥有的是以下内容:

library(xgboost)
dtrain <- map2(
  X_list, Y_list ~ xgb.DMatrix(data = .x, label = .y, missing = "NaN")
)

How can I use map2 to map over X_list and Y_list and create the xgb.DMatrix ?如何通过X_listY_list使用map2到 map 并创建xgb.DMatrix

The following works for one observation:以下适用于一项观察:

dtrain <- xgb.DMatrix(data = X_list[[1]], label = Y_list[[1]], missing = "NaN")

Data:数据:

X_list <- list(structure(c(-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, 
-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.25, 
-0.25, 0.626262626262626, 0.626262626262626, 0.626262626262626, 
0.626262626262626, 0.616161616161616, 0.616161616161616, 0.616161616161616, 
0.626262626262626, 0.606060606060606, 0.606060606060606, 0.616161616161616, 
0.626262626262626, 0.616161616161616, 0.626262626262626, 0.646464646464647, 
0.646464646464647, 0.653061224489796, 0.606060606060606, 0.595959595959596, 
0.595959595959596, 0.797979797979798, 0.818181818181818, 0.797979797979798, 
0.797979797979798, 0.797979797979798, 0.797979797979798, 0.797979797979798, 
0.808080808080808, 0.787878787878788, 0.787878787878788, 0.797979797979798, 
0.808080808080808, 0.787878787878788, 0.808080808080808, 0.818181818181818, 
0.818181818181818, 0.76530612244898, 0.767676767676768, 0.757575757575758, 
0.747474747474748, 0.0525742729373827, 0.0247209518021764, 0.0383024701142363, 
0.0431858541848299, 0.0373874353552574, 0.043102816902322, 0.0393864266632971, 
0.0453441092054323, -0.0332326842437988, -0.0673697134686685, 
-0.0756843013225178, -0.0944028996710843, -0.0889603051584127, 
-0.0839588622987959, -0.0693978939540514, -0.068952460676042, 
-0.0623124172598907, -0.0289416799234536, -0.0263188782910003, 
-0.0539180933744947, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 
3, 3, 3, 1, 1, 1, 1.41915800843232, 1.32400213226479, 1.24248186489683, 
1.23904026693494, 1.22578149610895, 1.22560759842505, 1.23716050117333, 
1.31093157484879, 1.36495946627377, 1.36987409128001, 1.32677838199057, 
1.59470485996637, 1.47030836098393, 1.38716143306261, 1.41122611821048, 
1.50470462559195, 2.5667628920672, 3.25998616689164, 2.72418353696655, 
2.77147826857654, 2.00816529922485, 1.9712685855706, 1.93614308715926, 
1.9851930297389, 1.9851930297389, 1.9851930297389, 1.9851930297389, 
1.9851930297389, 1.93614308715926, 1.93678125811656, 1.93994682647747, 
1.93994682647747, 1.93678125811656, 1.93994682647747, 1.98219354714358, 
1.98453679538454, 2.02345566404527, 1.9393462397422, 1.92741709002804, 
1.93456058692163, 0.555555555555556, 0.555555555555556, 0.555555555555556, 
0.545454545454545, 0.545454545454545, 0.545454545454545, 0.545454545454545, 
0.545454545454545, 0.545454545454545, 0.545454545454545, 0.535353535353535, 
0.535353535353535, 0.545454545454545, 0.535353535353535, 0.525252525252525, 
0.525252525252525, 0.525252525252525, 0.535353535353535, 0.505050505050505, 
0.505050505050505, 0.0962206695047701, 0.0645033172611389, 0.0789706705735339, 
0.0871419915715422, 0.0696361276044722, 0.0611974185989991, 0.0535584181214952, 
0.0733027435074016, 0.104008467087853, 0.0643646276048721, 0.0599217594491127, 
0.045252967966837, 0.0648458104086725, 0.0665422705724929, 0.0722633466679387, 
0.0579149894109999, 0.0671261068882733, 0.0659758622985851, 0.0902909795091504, 
0.0667567811930188), .Dim = c(20L, 9L), .Dimnames = list(NULL, 
    c("X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8", "X9"))), 
    structure(c(-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, 
    -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.25, 
    -0.25, -0.5, 0.626262626262626, 0.626262626262626, 0.626262626262626, 
    0.616161616161616, 0.616161616161616, 0.616161616161616, 
    0.626262626262626, 0.606060606060606, 0.606060606060606, 
    0.616161616161616, 0.626262626262626, 0.616161616161616, 
    0.626262626262626, 0.646464646464647, 0.646464646464647, 
    0.653061224489796, 0.606060606060606, 0.595959595959596, 
    0.595959595959596, 0.606060606060606, 0.818181818181818, 
    0.797979797979798, 0.797979797979798, 0.797979797979798, 
    0.797979797979798, 0.797979797979798, 0.808080808080808, 
    0.787878787878788, 0.787878787878788, 0.797979797979798, 
    0.808080808080808, 0.787878787878788, 0.808080808080808, 
    0.818181818181818, 0.818181818181818, 0.76530612244898, 0.767676767676768, 
    0.757575757575758, 0.747474747474748, 0.747474747474748, 
    0.0247209518021764, 0.0383024701142363, 0.0431858541848299, 
    0.0373874353552574, 0.043102816902322, 0.0393864266632971, 
    0.0453441092054323, -0.0332326842437988, -0.0673697134686685, 
    -0.0756843013225178, -0.0944028996710843, -0.0889603051584127, 
    -0.0839588622987959, -0.0693978939540514, -0.068952460676042, 
    -0.0623124172598907, -0.0289416799234536, -0.0263188782910003, 
    -0.0539180933744947, -0.0784106030875784, 3, 3, 3, 3, 3, 
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1.32400213226479, 
    1.24248186489683, 1.23904026693494, 1.22578149610895, 1.22560759842505, 
    1.23716050117333, 1.31093157484879, 1.36495946627377, 1.36987409128001, 
    1.32677838199057, 1.59470485996637, 1.47030836098393, 1.38716143306261, 
    1.41122611821048, 1.50470462559195, 2.5667628920672, 3.25998616689164, 
    2.72418353696655, 2.77147826857654, 2.8300838173709, 1.9712685855706, 
    1.93614308715926, 1.9851930297389, 1.9851930297389, 1.9851930297389, 
    1.9851930297389, 1.9851930297389, 1.93614308715926, 1.93678125811656, 
    1.93994682647747, 1.93994682647747, 1.93678125811656, 1.93994682647747, 
    1.98219354714358, 1.98453679538454, 2.02345566404527, 1.9393462397422, 
    1.92741709002804, 1.93456058692163, 1.9393462397422, 0.555555555555556, 
    0.555555555555556, 0.545454545454545, 0.545454545454545, 
    0.545454545454545, 0.545454545454545, 0.545454545454545, 
    0.545454545454545, 0.545454545454545, 0.535353535353535, 
    0.535353535353535, 0.545454545454545, 0.535353535353535, 
    0.525252525252525, 0.525252525252525, 0.525252525252525, 
    0.535353535353535, 0.505050505050505, 0.505050505050505, 
    0.535353535353535, 0.0645033172611389, 0.0789706705735339, 
    0.0871419915715422, 0.0696361276044722, 0.0611974185989991, 
    0.0535584181214952, 0.0733027435074016, 0.104008467087853, 
    0.0643646276048721, 0.0599217594491127, 0.045252967966837, 
    0.0648458104086725, 0.0665422705724929, 0.0722633466679387, 
    0.0579149894109999, 0.0671261068882733, 0.0659758622985851, 
    0.0902909795091504, 0.0667567811930188, 0.0719090569241118
    ), .Dim = c(20L, 9L), .Dimnames = list(NULL, c("X1", "X2", 
    "X3", "X4", "X5", "X6", "X7", "X8", "X9"))))


Y_list <- list(structure(c(0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 
1, 0, 1, 1, 1), .Dim = c(20L, 1L), .Dimnames = list(NULL, "Y_plus_1")), 
    structure(c(1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 
    0, 1, 1, 1, 0), .Dim = c(20L, 1L), .Dimnames = list(NULL, 
        "Y_plus_1")))

EDIT: Updated problem:编辑:更新的问题:

I can create the matrix as:我可以将矩阵创建为:

dtrain <- map2(
  X_list, Y_list, ~ xgb.DMatrix(data = .x, label = .y, missing = "NaN")
)

dtest <- map(
  X_list, ~ xgb.DMatrix(data = .x, missing = "NaN")
)

watchlist <- list("train" = dtrain)

params <- list("eta" = 0.1, "max_depth" = 5, "colsample_bytree" = 1, "min_child_weight" = 1, "subsample"= 1,
               "objective"="binary:logistic", "gamma" = 1, "lambda" = 1, "alpha" = 0, "max_delta_step" = 0,
               "colsample_bylevel" = 1, "eval_metric"= "auc",
               "set.seed" = 176)

xgb.model <- map(
  dtrain, ~ xgboost(params = params, data = .x, nrounds = 100, watchlist)
)

All the above works.以上所有工作。 My problem is with the prediction function:我的问题是预测 function:

xgb.pred <- map(
  dtest, ~ predict(xgb.model, data = .x, type = 'prob')
)

I get this error:我收到此错误:

Error in ets(object, lambda = lambda, biasadj = biasadj, allow.multiplicative.trend = allow.multiplicative.trend, : y should be a univariate time series ets 中的错误(对象,lambda = lambda,biasadj = biasadj,allow.multiplicative.trend = allow.multiplicative.trend,:y 应该是单变量时间序列

I now have a large list xgb.model and a normal list dtest along with a function predict .我现在有一个大列表xgb.model和一个普通列表dtest以及一个 function predict

This fails also:这也失败了:

xgb.pred <- map2(
  dtest, xgb.model, ~ predict(object = .x, model = .y, type = 'prob')
)

Error:错误:

Error in forecast.ts(object, ...): Unknown model class predict.ts(object, ...) 中的错误:未知 model class

(I can provide more data if needed) (如果需要,我可以提供更多数据)

Try this尝试这个

xgb.pred <- map2(
  .x = xgb.model, 
  .y = dtest, 
  .f = ~ predict(.x, newdata = .y, type = 'prob')
)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM