簡體   English   中英

使用 pytorch 對數據集進行過采樣

[英]Oversampling the dataset with pytorch

我對 PyTorch 和 python 很陌生。 我有一個二進制分類問題,其中一個 class 的樣本比另一個多,所以我決定通過對其進行更多增強來對樣本數量較少的 class 進行過采樣,例如,我將從一個樣本中生成 7 個圖像一個 class,而另一個 class 我將從一個樣本中生成 3 個圖像。 我正在使用 imguag 進行 PyTorch 的擴充,所以我不確定哪個更好,首先擴充我的數據集,然后將其傳遞給 torch.utils.data.Dataset class,或讀取數據並在init ZC1C4D1AB1545Z777778 中擴充它數據集 class。

我認為還有另一種處理不平衡數據的方法, nn.BCELoss是二進制分類問題的常見選擇,可以設置一個pos_weight來平衡正負樣本。 如果這樣做,您可以對所有樣本應用相同的增強。 這是代碼:

# defines the augmentation
transform = transforms.Compose([transforms.RandomRotation(20),
                            transforms.Resize((32, 32)),
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# initializes the data set
dataset = Dataset(train_data_path, transforms=transform)
# defines the loss function
criterion = torch.nn.BCELoss(torch.tensor([10.]))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM