Tidymodels: Creating an rsplit object from training and testing data

I'm trying to make the jump from Scikit-Learn to Tidymodels and most of the time it has been relatively painless thanks to the tutorials from Julia Silge and Andrew Couch. However, now I'm stuck. Normally I would use initial_split(df, strata = x) to get a split object to work with. But this time I've been provided with the test and train sets from a different department and I'm afraid this might become the norm. Without a split object functions like last_fit() and collect_predictions() don't work.

How can I reverse engineer the provided datasets so that they become rsplit objects? Or alternatively, is it possible to bind the datasets together first and then tell initial_split() exactly what rows should go to train and test?

I see that someone asked the same question at https://community.rstudio.com/t/tidymodels-creating-a-split-object-from-testing-and-training-data-perform-last-fit/69885 . Max Kuhn said you could reverse engineer an rsplit object but I didn't understand how. Thanks!

# Example data
library(tibble)

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

Reverse engineering the split object most likely just means looking at how an rsplit object is constructed. Depending on the package implementation, this can be as simple as building an object with the same fields as the one returned by initial_split. That is most likely the case here, so we would only have to recreate the object and make sure all of its fields are present.

The simplest method, however, is probably to combine the two data frames and pass the row indices to make_splits to recreate the original split pair:

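library(rsample)
library(dplyr)

# stack the provided sets so that the train rows come first
combined <- bind_rows(train, test)

# row positions of each set within the combined data
ind <- list(analysis = seq_len(nrow(train)),
            assessment = nrow(train) + seq_len(nrow(test)))

splits <- make_splits(ind, combined)
splits
#> <Analysis/Assess/Total>
#> <8/4/12>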
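
As a possible shortcut to the index approach above: more recent rsample releases also provide a data frame method for make_splits() that takes the two sets directly. The sketch below assumes such a version is installed (the method is not available in the rsample bundled with tidymodels 0.1.1) and reuses the question's example data.

```r
library(rsample)
library(tibble)

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

# Data frame method: the first argument is the analysis (training) set,
# `assessment` is the test set. Assumes a recent rsample version.
splits <- make_splits(train, assessment = test)

# training() and testing() should hand back the original row counts
nrow(training(splits))
nrow(testing(splits))
```

Both data frames need the same columns for this to work, since they end up stored together inside the split object.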

Another option is initial_time_split(), which takes the first prop share of the rows for training instead of a random selection. This works as long as the training rows come first in the combined data.

library(tidymodels)
#> -- Attaching packages ---------- tidymodels 0.1.1 --

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

data <- bind_rows(train, test)

prop <- nrow(train) / (nrow(train) + nrow(test))

split <- initial_time_split(data, prop = prop)

train_split <- training(split)
test_split <- testing(split)

all_equal(train, train_split)
#> [1] TRUE
all_equal(test, test_split)
#> [1] TRUE

Created on 2020-09-22 by the reprex package (v0.3.0)
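
To connect this back to the question, here is a minimal sketch of passing such a split to last_fit() and collect_predictions(). The logistic regression model, the formula, and the conversion of the outcome to a factor are illustrative assumptions and not part of the original posts. Note that the outcome column in the example data happens to be named predictor. With only eight training rows, glm may emit fitting warnings.

```r
library(tidymodels)

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

# Classification models need a factor outcome (assumption: binary outcome)
data <- bind_rows(train, test) %>%
  mutate(predictor = factor(predictor))

# The first nrow(train) rows become the training set
split <- initial_time_split(data, prop = nrow(train) / nrow(data))

# Hypothetical model choice, for illustration only
wf <- workflow() %>%
  add_formula(predictor ~ feature_1 + feature_2) %>%
  add_model(logistic_reg() %>% set_engine("glm"))

# Fit on the training rows, evaluate on the held-out test rows
final_fit <- last_fit(wf, split = split)
preds <- collect_predictions(final_fit)
```

preds should contain one row per test observation, and collect_metrics(final_fit) gives the corresponding test set metrics.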

Reverse Engineering

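If you check the structure of the rsplit object, you will see that it has an $in_id element that lists the row numbers of the training samples. You can overwrite it manually with the predefined training row numbers. Keep in mind that this relies on the object's internal structure rather than on a documented interface, so make_splits() is the safer route.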

library(tidymodels)
#> -- Attaching packages -- tidymodels 0.1.1 --

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

data <- bind_rows(train, test, .id = "dataset") %>% 
  mutate(dataset = factor(dataset, labels = c("train", "test")))

train_ids <- which(data$dataset == "train")

split <- initial_split(data)

# change split$in_id to include the predefined train samples
split$in_id <- train_ids

train_split <- training(split) %>% 
  select(-dataset)
test_split <- testing(split) %>% 
  select(-dataset)

all_equal(train, train_split)
#> [1] TRUE
all_equal(test, test_split)
#> [1] TRUE

Created on 2020-09-22 by the reprex package (v0.3.0)
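
To see why overwriting $in_id is enough, it can help to inspect the object directly. The sketch below uses a hypothetical 12-row toy tibble and rsample's complement() helper, which reports the assessment rows implied by the split. It assumes the object exposes in_id as described above, which may differ between rsample versions.

```r
library(rsample)
library(tibble)

# Hypothetical toy data, used only to inspect the split object
toy <- tibble(x = 1:12)

split <- initial_split(toy)

# List the elements stored in the rsplit object
names(split)

# Overwrite the analysis rows with a predefined set
split$in_id <- 1:8

# The assessment rows are derived as everything not listed in in_id
complement(split)

nrow(training(split))
nrow(testing(split))
```

Because the assessment rows are derived from whatever is not in in_id, changing that one element updates both training() and testing().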
