
How to balance (oversample) unbalanced data in PyTorch (with WeightedRandomSampler)?

I have a 2-class problem and my data is highly unbalanced: I have 232550 samples from one class and 13498 from the second class. The PyTorch docs and the internet tell me to use the WeightedRandomSampler class for my DataLoader.

I have tried using the WeightedRandomSampler but I keep getting errors.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler

    trainratio = np.bincount(trainset.labels)          # samples per class
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)  # inverse class frequency
    train_sampleweights = train_weights[trainset.labels]           # one weight per sample
    train_sampler = WeightedRandomSampler(weights=train_sampleweights,
                                          num_samples=len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler,
                             shuffle=False)
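
Why inverse class frequencies balance the classes: `WeightedRandomSampler` draws each index with probability proportional to its weight, so if every sample of class `c` gets weight `1 / count[c]`, each class carries a total weight of `count[c] * (1 / count[c]) = 1` and both classes are drawn with probability 1/2. Without weights, a uniform draw would pick the minority class only about 5.5% of the time (13498 / 246048). A quick pure-Python check with the class counts from the question (the helper name `class_draw_probabilities` is hypothetical, for illustration only):

```python
from fractions import Fraction


def class_draw_probabilities(class_counts):
    """Probability that one weighted draw picks each class when every
    sample of class c is given weight 1 / class_counts[c]."""
    # Total weight carried by each class: count * (1 / count) == 1.
    class_totals = [count * Fraction(1, count) for count in class_counts]
    grand_total = sum(class_totals)
    return [total / grand_total for total in class_totals]


# Class counts quoted in the question.
counts = [232550, 13498]
print(class_draw_probabilities(counts))  # [Fraction(1, 2), Fraction(1, 2)]
```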

I cannot see why I am getting this error when initializing the WeightedRandomSampler class.

I have tried other similar workarounds, but so far every attempt produces an error. How should I implement this to balance my train, validation and test data?

I am currently getting this error:

    train_sampleweights = train_weights[trainset.labels]
    ValueError: too many dimensions 'str'

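The problem is the type of `trainset.labels`: the error message indicates that it contains strings, and a tensor cannot be indexed with strings. To fix the error, convert `trainset.labels` to a numeric type before computing the weights. Note that index tensors must hold integers (`torch.long`), so convert the labels to `int` rather than `float`; indexing with a float tensor raises an `IndexError`.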
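
A minimal sketch of the corrected pipeline, assuming the labels arrive as strings such as `"0"`/`"1"`. The question does not show how `trainset` is built, so `ToyDataset` below is a hypothetical stand-in with a 95%/5% split. The labels are converted to `int64` once, up front, so both `np.bincount` and the tensor indexing work:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class ToyDataset(Dataset):
    """Hypothetical stand-in for `trainset`: random features plus string labels."""

    def __init__(self, labels):
        self.labels = labels                          # e.g. ["0", "0", "1", ...]
        self.features = torch.randn(len(labels), 4)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], int(self.labels[idx])


torch.manual_seed(0)
trainset = ToyDataset(["0"] * 950 + ["1"] * 50)       # 95% / 5% imbalance

# Convert the labels to integers once; strings cannot index a tensor.
labels = torch.tensor([int(y) for y in trainset.labels], dtype=torch.long)

classcount = np.bincount(labels.numpy())              # samples per class
train_weights = 1.0 / torch.tensor(classcount, dtype=torch.float)
train_sampleweights = train_weights[labels]           # one weight per sample

# Float labels would fail here too: index tensors must be integer or bool.
try:
    train_weights[labels.float()]
except IndexError as err:
    print("float labels:", err)

train_sampler = WeightedRandomSampler(weights=train_sampleweights,
                                      num_samples=len(train_sampleweights),
                                      replacement=True)
trainloader = DataLoader(trainset, batch_size=100, sampler=train_sampler,
                         shuffle=False)

# Count how often each class appears in one epoch of oversampled batches.
seen = torch.zeros(2, dtype=torch.long)
for _, y in trainloader:
    seen += torch.bincount(y, minlength=2)
print(seen)  # roughly equal counts for both classes
```

Only the training loader is oversampled in this sketch. Validation and test loaders are commonly left with their natural class distribution (a plain `DataLoader(..., shuffle=False)`), so that evaluation metrics reflect the data the model will actually see.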
