How to balance (oversampling) unbalanced data in PyTorch (with WeightedRandomSampler)?
I have a two-class problem and my data is imbalanced: class 0 has 232,550 samples and class 1 has 13,498. The PyTorch documentation and the internet tell me to use WeightedRandomSampler for my DataLoader.

I have tried using WeightedRandomSampler, but I keep getting an error message.
```python
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# trainset.labels is a list of integer class labels, e.g. [0, 1, 0, 0, 0, ...]
trainratio = np.bincount(trainset.labels)
classcount = trainratio.tolist()
train_weights = 1. / torch.tensor(classcount, dtype=torch.float)
train_sampleweights = train_weights[trainset.labels]
train_sampler = WeightedRandomSampler(weights=train_sampleweights,
                                      num_samples=len(train_sampleweights))
# shuffle must stay False when a custom sampler is supplied
trainloader = DataLoader(trainset, sampler=train_sampler, shuffle=False)
```
I printed out some of the sizes:

```
train_weights = tensor([4.3002e-06, 4.3002e-06, 4.3002e-06, ...,
                        4.3002e-06, 4.3002e-06, 4.3002e-06])
train_weights shape = torch.Size([246048])
```
I don't see why I am getting this error:

```
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.weights = torch.tensor(weights, dtype=torch.double)
```
I have tried other, similar workarounds, but so far every attempt produces some error. How should I implement this to balance my training, validation, and test data?
So apparently this is an internal warning rather than an error. According to the PyTorch folks, I can keep coding and not worry about the warning message.
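For reference, here is a minimal, self-contained sketch of the same setup, with toy labels and illustrative variable names. The warning comes from the sampler calling `torch.tensor(weights)` on a tensor internally, so one way to silence it is to pass the weights as a plain Python list:

```python
import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 1, 0, 1, 0, 0])  # toy imbalanced labels
class_counts = torch.bincount(labels)             # samples per class: [6, 2]
class_weights = 1.0 / class_counts.float()        # inverse-frequency class weights
sample_weights = class_weights[labels]            # one weight per sample

# Passing a list (not a tensor) avoids the copy-construct UserWarning,
# since WeightedRandomSampler builds its own tensor from the argument.
sampler = WeightedRandomSampler(weights=sample_weights.tolist(),
                                num_samples=len(sample_weights),
                                replacement=True)

drawn = list(sampler)  # indices drawn proportionally to the weights
print(len(drawn))      # one index per requested sample
```

With `replacement=True` (the default), minority-class samples are drawn repeatedly, so each epoch is roughly class-balanced even though class 1 has far fewer examples.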