簡體   English   中英

帶有PyTorch的多標簽,多類別圖像分類器(ConvNet)

[英]Multi-label, multi-class image classifier (ConvNet) with PyTorch

我正在嘗試使用PyTorch實現圖像分類器(CNN / ConvNet),在這里我想從csv文件讀取標簽。 我有4個不同的類別,一張圖片可能屬於多個類別。

我已經閱讀了PyTorch教程本斯坦福教程以及 教程 ,但都沒有涵蓋我的具體情況。 我設法建立了torch.utils.data.Dataset類的自定義函數,該函數僅對於從二進制分類器的csv文件讀取標簽有效。

這是我到目前為止擁有的torch.utils.data.Dataset類的代碼(與上面鏈接的第三個教程稍作修改):

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import numpy as np
import pandas as pd


class MyCustomDataset(data.Dataset):
# __init__ function is where the initial logic happens like reading a csv,
# assigning transforms etc.
def __init__(self, csv_path):
    # Transforms
    self.random_crop = transforms.RandomCrop(800)
    self.to_tensor = transforms.ToTensor()
    # Read the csv file
    self.data_info = pd.read_csv(csv_path, header=None)
    # First column contains the image paths
    self.image_arr = np.asarray(self.data_info.iloc[:, 0])
    # Second column is the labels
    self.label_arr = np.asarray(self.data_info.iloc[:, 1])
    # Calculate len
    self.data_len = len(self.data_info.index)


# __getitem__ function returns the data and labels. This function is
# called from dataloader like this
def __getitem__(self, index):
    # Get image name from the pandas df
    single_image_name = self.image_arr[index]
    # Open image
    img_as_img = Image.open(single_image_name)
    img_cropped = self.random_crop(img_as_img)
    img_as_tensor = self.to_tensor(img_cropped)

    # Get label(class) of the image based on the cropped pandas column
    single_image_label = self.label_arr[index]

    return (img_as_tensor, single_image_label)

def __len__(self):
    return self.data_len

具體來說,我正在嘗試從具有以下結構的文件中讀取標簽:

CSV數據

我的具體問題是,我無法弄清楚如何將其實現到我的Dataset類中。 我想我在csv中的標簽的(手動)分配與PyTorch如何讀取它們之間缺少聯系,因為我對框架不是很熟悉。
我非常感謝您提供有關如何使其正常工作的幫助,或者如果確實有涉及此方面的示例,那么也將非常感謝您提供鏈接!

也許我缺少了一些東西,但是如果您想將列1..N (此處N = 4 )轉換為標簽向量或形狀(N,) (例如,給出示例數據, label(img1) = [0, 0, 0, 1]label(img3) = [1, 0, 1, 0] ,...),為什么不:

  1. 將所有標簽列讀入self.label_arr

     self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # columns 1 to N 
  2. 相應地返回__getitem__()的標簽(此處不變):

     single_image_label = self.label_arr[index] 

為了訓練您的分類器,您可以然后計算例如(N,)預測與目標標簽之間的交叉熵。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM