繁体   English   中英

如何在PyTorch中平衡(过采样)不平衡数据(使用WeightedRandomSampler)?

[英]How to balance (oversampling) unbalanced data in PyTorch (with WeightedRandomSampler)?

我有2类问题,我的数据高度不平衡。 我有来自一堂课的232550个样本和来自第二堂课的13498样本。 PyTorch文档和互联网告诉我为我的DataLoader使用类WeightedRandomSampler。

我已经尝试过使用WeightedRandomSampler,但是我一直收到错误消息。

    trainratio = np.bincount(trainset.labels)
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
    num_samples = len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
    shuffle=False)

初始化WeightedRandomSampler类时,为什么看不到此错误?

我尝试了其他类似的解决方法,但到目前为止,所有尝试均会产生一些错误。 我应该如何实施以平衡训练,验证和测试数据?

当前出现此错误:

train__sampleweights = train_weights [trainset.labels] ValueError:尺寸'str'过多

问题出在trainset.labels的类型中为了解决错误,可以将trainset.labels转换为float

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM