簡體   English   中英

如何使用WeightedRandomSampler平衡PyTorch中的不平衡數據?

[英]How to balance unbalanced data in PyTorch with WeightedRandomSampler?

我有2類問題,我的數據不平衡。 0類有232550個樣本,1類有13498個樣本。 PyTorch文檔和互聯網告訴我為我的DataLoader使用類WeightedRandomSampler。

我已經嘗試過使用WeightedRandomSampler,但是我一直收到錯誤消息。

    trainratio = np.bincount(trainset.labels) #trainset.labels is a list of 
    float [0,1,0,0,0,...] 
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
                                 num_samples=len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
                                       shuffle=False)

我打印出一些尺寸:

train_weights = tensor([4.3002e-06, 4.3002e-06, 4.3002e-06,  ..., 
4.3002e-06, 4.3002e-06, 4.3002e-06])

train_weights shape=  torch.Size([246048])

我看不到為什么出現此錯誤:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.weights = torch.tensor(weights, dtype=torch.double)

我嘗試了其他類似的解決方法,但到目前為止,所有嘗試均會產生一些錯誤。 我應該如何實施以平衡訓練,驗證和測試數據?

因此,顯然這是內部警告,而不是錯誤。 根據PyTorch的家伙說,我可以繼續編碼,而不必擔心警告消息。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM