
How to fit custom data into Pytorch DataLoader?

I have pre-processed and normalized my data and split it into a training set and a testing set. My x_train and y_train have the following shapes:

Shape of X_Train: (708, 256, 3)
Shape of Y_Train: (708, 4)

As you can see, x_train is 3-D. How can I feed it into a PyTorch DataLoader? What should go in the class below?

import torch
from torch.utils.data import Dataset

class training_set(Dataset):
    def __init__(self,X,Y):
        pass  # what goes here?

    def __len__(self):
        return  # ?

    def __getitem__(self, idx):
        return  # ?

training_set = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = torch.utils.data.DataLoader(training_set, batch_size=50, shuffle=True)
import torch
from torch.utils import data

x_train, y_train = torch.rand((708, 256, 3)), torch.rand((708, 4))  # random stand-in data with the question's shapes

class training_set(data.Dataset):
    def __init__(self,X,Y):
        self.X = X                           # set data
        self.Y = Y                           # set labels

    def __len__(self):
        return len(self.X)                   # return length

    def __getitem__(self, idx):
        return [self.X[idx], self.Y[idx]]    # return one sample as [data, label]; DataLoader stacks samples into batches

training_dataset = training_set(x_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=True)
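A quick way to confirm that the 3-D input needs no special handling is to index the dataset and iterate the loader: `__getitem__` only indexes the first dimension, so each sample is a (256, 3) tensor and the DataLoader stacks 50 of them per batch. The sketch below uses random stand-in tensors with the question's shapes; `TrainingSet` is simply a renamed copy of the class above.

```python
import torch
from torch.utils import data


class TrainingSet(data.Dataset):
    """Same logic as the answer's training_set class, renamed for clarity."""

    def __init__(self, X, Y):
        self.X = X  # features, shape (N, 256, 3)
        self.Y = Y  # labels, shape (N, 4)

    def __len__(self):
        return len(self.X)  # number of samples = size of the first dimension

    def __getitem__(self, idx):
        # Indexing only the first dimension returns one (256, 3) sample and one (4,) label
        return [self.X[idx], self.Y[idx]]


x_train, y_train = torch.rand((708, 256, 3)), torch.rand((708, 4))
dataset = TrainingSet(x_train, y_train)
loader = data.DataLoader(dataset, batch_size=50, shuffle=True)

x0, y0 = dataset[0]
print(len(dataset), tuple(x0.shape), tuple(y0.shape))  # 708 (256, 3) (4,)

# Collect the size of every batch in one epoch
batch_sizes = [xb.shape[0] for xb, yb in loader]
print(len(batch_sizes), batch_sizes[-1])  # 15 8  (14 full batches of 50, then 8 left over)
```

If the short final batch is unwanted, `DataLoader` also accepts `drop_last=True`.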

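Actually, you don't need a custom dataset here, because your data is simple: two tensors that share the same first dimension (708 samples). Make sure both are tensors (convert NumPy arrays with torch.from_numpy first), then wrap them directly in a TensorDataset, which indexes every tensor along its first dimension: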

training_dataset = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=True)
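`TensorDataset` only accepts tensors, and all of them must have the same size along the first dimension. The question doesn't say what type x_train and y_train are; assuming they are NumPy arrays, a minimal conversion sketch follows. `torch.from_numpy` wraps the array without copying, and `.float()` then makes a float32 copy, since NumPy defaults to float64 while most PyTorch models use float32. The arrays here are random, hypothetical stand-ins.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the question's pre-processed arrays (float64 by default)
x_np = np.random.rand(708, 256, 3)
y_np = np.random.rand(708, 4)

# from_numpy wraps the array; .float() then produces a float32 copy
x_train = torch.from_numpy(x_np).float()
y_train = torch.from_numpy(y_np).float()

training_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(training_dataset, batch_size=50, shuffle=True)

# Inspect the first batch
xb, yb = next(iter(train_loader))
print(xb.dtype, tuple(xb.shape), tuple(yb.shape))  # torch.float32 (50, 256, 3) (50, 4)
```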

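Both approaches produce batches of the same form: each iteration yields [x_batch, y_batch] with shapes (50, 256, 3) and (50, 4), except the last batch, which holds the remaining 8 samples (708 = 14 × 50 + 8).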
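To check the equivalence directly, one can build both datasets over the same tensors and compare their loaders with `shuffle=False`, so both visit the samples in the same order. With `shuffle=True` the two loaders draw different random orders, but the batch structure is identical. The class name `TrainingSet` and the random data are assumptions for illustration.

```python
import torch
from torch.utils import data


class TrainingSet(data.Dataset):
    # Custom dataset equivalent to the answer's training_set class
    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return [self.X[idx], self.Y[idx]]


x_train, y_train = torch.rand((708, 256, 3)), torch.rand((708, 4))

# shuffle=False so both loaders iterate samples in index order
custom_loader = data.DataLoader(TrainingSet(x_train, y_train), batch_size=50, shuffle=False)
tensor_loader = data.DataLoader(data.TensorDataset(x_train, y_train), batch_size=50, shuffle=False)

# Compare every batch of features and labels element-wise
all_equal = len(custom_loader) == len(tensor_loader) and all(
    torch.equal(xa, xb) and torch.equal(ya, yb)
    for (xa, ya), (xb, yb) in zip(custom_loader, tensor_loader)
)
print(all_equal)  # True
```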
