[英]How to balance unbalanced data in PyTorch with WeightedRandomSampler?
[英]How to balance (oversampling) unbalanced data in PyTorch (with WeightedRandomSampler)?
我有2類問題,我的數據高度不平衡。 我有來自一堂課的232550個樣本和來自第二堂課的13498
樣本。 PyTorch文檔和互聯網告訴我為我的DataLoader使用類WeightedRandomSampler。
我已經嘗試過使用WeightedRandomSampler,但是我一直收到錯誤消息。
trainratio = np.bincount(trainset.labels)
classcount = trainratio.tolist()
train_weights = 1./torch.tensor(classcount, dtype=torch.float)
train_sampleweights = train_weights[trainset.labels]
train_sampler = WeightedRandomSampler(weights=train_sampleweights,
num_samples = len(train_sampleweights))
trainloader = DataLoader(trainset, sampler=train_sampler,
shuffle=False)
初始化WeightedRandomSampler類時,為什么看不到此錯誤?
我嘗試了其他類似的解決方法,但到目前為止,所有嘗試均會產生一些錯誤。 我應該如何實施以平衡訓練,驗證和測試數據?
當前出現此錯誤:
train__sampleweights = train_weights [trainset.labels] ValueError:尺寸'str'過多
問題出在trainset.labels的類型中為了解決錯誤,可以將trainset.labels轉換為float
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.