[英]Multi-label, multi-class image classifier (ConvNet) with PyTorch
我正在嘗試使用PyTorch實現圖像分類器(CNN / ConvNet),在這里我想從csv文件讀取標簽。 我有4個不同的類別,一張圖片可能屬於多個類別。
我已經閱讀了PyTorch教程 , 本斯坦福教程以及本 教程 ,但都沒有涵蓋我的具體情況。 我設法建立了torch.utils.data.Dataset
類的自定義函數,該函數僅對於從二進制分類器的csv文件讀取標簽有效。
這是我到目前為止擁有的torch.utils.data.Dataset
類的代碼(與上面鏈接的第三個教程稍作修改):
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import numpy as np
import pandas as pd
class MyCustomDataset(data.Dataset):
# __init__ function is where the initial logic happens like reading a csv,
# assigning transforms etc.
def __init__(self, csv_path):
# Transforms
self.random_crop = transforms.RandomCrop(800)
self.to_tensor = transforms.ToTensor()
# Read the csv file
self.data_info = pd.read_csv(csv_path, header=None)
# First column contains the image paths
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# Second column is the labels
self.label_arr = np.asarray(self.data_info.iloc[:, 1])
# Calculate len
self.data_len = len(self.data_info.index)
# __getitem__ function returns the data and labels. This function is
# called from dataloader like this
def __getitem__(self, index):
# Get image name from the pandas df
single_image_name = self.image_arr[index]
# Open image
img_as_img = Image.open(single_image_name)
img_cropped = self.random_crop(img_as_img)
img_as_tensor = self.to_tensor(img_cropped)
# Get label(class) of the image based on the cropped pandas column
single_image_label = self.label_arr[index]
return (img_as_tensor, single_image_label)
def __len__(self):
return self.data_len
具體來說,我正在嘗試從具有以下結構的文件中讀取標簽:
我的具體問題是,我無法弄清楚如何將其實現到我的Dataset
類中。 我想我在csv中的標簽的(手動)分配與PyTorch如何讀取它們之間缺少聯系,因為我對框架不是很熟悉。
我非常感謝您提供有關如何使其正常工作的幫助,或者如果確實有涉及此方面的示例,那么也將非常感謝您提供鏈接!
也許我缺少了一些東西,但是如果您想將列1..N
(此處N = 4
)轉換為標簽向量或形狀(N,)
(例如,給出示例數據, label(img1) = [0, 0, 0, 1]
, label(img3) = [1, 0, 1, 0]
,...),為什么不:
將所有標簽列讀入self.label_arr
:
self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # columns 1 to N
相應地返回__getitem__()
的標簽(此處不變):
single_image_label = self.label_arr[index]
為了訓練您的分類器,您可以然后計算例如(N,)
預測與目標標簽之間的交叉熵。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.