
How do I make custom pytorch datasets structured like the torchvision datasets?

I'm new to PyTorch and I'm trying to reuse a Fashion MNIST CNN (from deeplizard) to categorize my timeseries data. I'm finding it hard to understand the structure of datasets: following this official tutorial and this SO question as best I can, I end up with something too simple, and I think that's because I don't understand OOP very well. The dataset I've made works fine in my CNN for training, but when I try to analyse the results with their code I get stuck.

So I create a dataset from two PyTorch tensors, features of shape [4050, 1, 150, 6] and targets of shape [4050]:

train_dataset = TensorDataset(features, targets)  # create your dataset
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=False)  # create your dataloader
print(train_dataset.__dict__.keys()) # list the attributes

I get this printed output from inspecting the attributes

dict_keys(['tensors'])

But in the Fashion MNIST tutorial they access the data like this:

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
print(train_set.__dict__.keys()) # list the attributes

And you get this printed output from inspecting the attributes

dict_keys(['root', 'transform', 'target_transform', 'transforms', 'train', 'data', 'targets'])

My dataset works fine for training, but in the later analysis parts of the tutorial, where the code accesses attributes of the dataset directly, I get an error:

# Analytics
prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
train_preds = get_all_preds(network, prediction_loader)
preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()

print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_dataset))

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-73-daa87335a92a> in <module>
      4 prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
      5 train_preds = get_all_preds(network, prediction_loader)
----> 6 preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()
      7 
      8 print('total correct:', preds_correct)

AttributeError: 'TensorDataset' object has no attribute 'targets'

Can anyone tell me what's going on here? Is this something I need to change in how I make the datasets, or can I rewrite the analysis code somehow to access the right part of the dataset?

The equivalent of .targets for a TensorDataset is train_dataset.tensors[1].
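
For example, the analysis snippet from the question can be rewritten with that substitution (a minimal sketch; network and get_all_preds are assumed to be defined as in the tutorial):

# Analytics, using tensors[1] instead of the missing .targets attribute
prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
train_preds = get_all_preds(network, prediction_loader)

# train_dataset.tensors[1] is the second tensor passed to TensorDataset, i.e. the targets
preds_correct = train_preds.argmax(dim=1).eq(train_dataset.tensors[1]).sum().item()

print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_dataset))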

The implementation of TensorDataset is very simple:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
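
Alternatively, if you want a dataset that exposes .data and .targets the way the torchvision datasets do, you can write a small Dataset subclass yourself. This is only a minimal sketch (the class name TimeseriesDataset and the optional transform argument are my own additions, not part of the tutorial):

import torch
from torch.utils.data import Dataset, DataLoader

class TimeseriesDataset(Dataset):
    """Wraps a data tensor and a targets tensor, exposing them as
    .data and .targets like the torchvision datasets do."""

    def __init__(self, data, targets, transform=None):
        assert data.size(0) == targets.size(0)
        self.data = data            # e.g. shape [4050, 1, 150, 6]
        self.targets = targets      # e.g. shape [4050]
        self.transform = transform

    def __getitem__(self, index):
        sample, target = self.data[index], self.targets[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return self.data.size(0)

train_dataset = TimeseriesDataset(features, targets)
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=False)
print(train_dataset.targets.shape)  # .targets now exists, as in the FashionMNIST example

With a dataset built this way, the analysis code from the tutorial works unchanged, because train_dataset.targets is a real attribute.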
