How do I update my ImageFolder dataset in PyTorch?

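I am working with a dataset where I need to find the accuracy for classes that have fewer than 20 samples. So first I use PyTorch's ImageFolder to load all the images in the folder.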

from torchvision.datasets import ImageFolder

dataset = ImageFolder('/content/drive/MyDrive/data/Dataset/')

Now, to get the classes with fewer than 20 samples, I use:

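def get_class_distribution(dataset_obj):
    # Map each class index back to its class name.
    idx2class = {v: k for k, v in dataset_obj.class_to_idx.items()}
    count_dict = {k: 0 for k in dataset_obj.class_to_idx}

    # Iterating the dataset loads every image; only the label (element[1]) is used.
    for element in dataset_obj:
        y_lbl = idx2class[element[1]]
        count_dict[y_lbl] += 1

    return count_dict

# print("Distribution of classes: \n", get_class_distribution(dataset))
class_distribution = get_class_distribution(dataset)

# Keep the names of the classes with fewer than 20 samples.
sampled_classes = [classes for (classes, samples) in class_distribution.items() if samples < 20]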

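I get the list of classes correctly, but I am unsure how to proceed to inference. How can I convert/update this into an ImageFolder so that I can use the filtered dataset in the following code: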

# Test model performance for classes with less than 20 samples.

y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in tqdm(data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_test_pred = model(x_batch)
        _, y_pred_tag = torch.max(y_test_pred, dim = 1)
        y_pred_list.append(y_pred_tag.cpu().numpy())
        y_true_list.append(y_batch.cpu().numpy())

You don't need to write the first block.
Use this instead:

import torch
from torchvision import datasets
from tqdm import tqdm

# model, device and test_transforms are assumed to be defined already.
test_data = datasets.ImageFolder('test/', transform=test_transforms)
data_loader = torch.utils.data.DataLoader(test_data, batch_size=16)

accuracy = 0.0  # running sum of per-batch accuracies
with torch.no_grad():
    for x_batch, y_batch in tqdm(data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_test_pred = model(x_batch)
        top_p, top_class = y_test_pred.topk(1, dim=1)
        equals = top_class == y_batch.view(*top_class.shape)
        accuracy += torch.mean(equals.float()).item()

# Mean of the per-batch accuracies, as a percentage.
print(accuracy / len(data_loader) * 100)
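
The loop above measures accuracy over the whole test folder. If the goal is accuracy only for the rare classes, one possible approach (a sketch, not part of the original answer) is to pick out the sample indices of those classes and evaluate on that subset. The helper names `rare_class_indices` and `accuracy_on_classes` are hypothetical, and the toy label list stands in for the `targets` attribute of an ImageFolder.

```python
from collections import Counter


def rare_class_indices(targets, max_samples):
    """Return (rare_classes, indices) for classes with fewer than max_samples samples.

    `targets` is a sequence of integer class labels, one per sample, in
    dataset order (for an ImageFolder this would be `dataset.targets`).
    """
    counts = Counter(targets)
    rare_classes = {label for label, n in counts.items() if n < max_samples}
    indices = [i for i, label in enumerate(targets) if label in rare_classes]
    return rare_classes, indices


def accuracy_on_classes(y_true, y_pred, classes):
    """Accuracy in percent over the samples whose true label is in `classes`."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t in classes]
    if not pairs:
        return float("nan")  # no sample belongs to the requested classes
    correct = sum(1 for t, p in pairs if t == p)
    return 100.0 * correct / len(pairs)


# Toy labels standing in for dataset.targets: class 0 has 3 samples, class 1 has 1.
toy_targets = [0, 0, 1, 0]
rare, idx = rare_class_indices(toy_targets, max_samples=2)
```

With a real ImageFolder, `rare, idx = rare_class_indices(dataset.targets, 20)` followed by `torch.utils.data.Subset(dataset, idx)` should give a dataset that can be passed to `DataLoader` in place of `test_data`. The labels keep their original class indices, so the model outputs still line up. Alternatively, the `y_true_list` and `y_pred_list` batches collected in the question can be flattened (for example with `numpy.concatenate`) and passed to `accuracy_on_classes` together with `rare`.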

