簡體   English   中英

如何在PyTorch中平衡(過采樣)不平衡數據(使用WeightedRandomSampler)?

[英]How to balance (oversampling) unbalanced data in PyTorch (with WeightedRandomSampler)?

我有2類問題,我的數據高度不平衡。 我有來自一堂課的232550個樣本和來自第二堂課的13498樣本。 PyTorch文檔和互聯網告訴我為我的DataLoader使用類WeightedRandomSampler。

我已經嘗試過使用WeightedRandomSampler,但是我一直收到錯誤消息。

    trainratio = np.bincount(trainset.labels)
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
    num_samples = len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
    shuffle=False)

初始化WeightedRandomSampler類時,為什么看不到此錯誤?

我嘗試了其他類似的解決方法,但到目前為止,所有嘗試均會產生一些錯誤。 我應該如何實施以平衡訓練,驗證和測試數據?

當前出現此錯誤:

train__sampleweights = train_weights [trainset.labels] ValueError:尺寸'str'過多

問題出在trainset.labels的類型中為了解決錯誤,可以將trainset.labels轉換為float

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM