How to split data into train and test sets from one directory with PyTorch?
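I have a folder of data that is not split into train and test folders. How can I split the data into training and test sets? The labels come from the names of the files, so any change to their order has to carry the labels along. I want to split the data before using ImageFolder so that different transforms can be applied to the training and test datasets.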
import torch
from torchvision import datasets, transforms

data_dir = "path/to/data"  # placeholder: the single, unsplit image folder

# Random augmentation for training
train_transforms = transforms.Compose([transforms.RandomRotation(10),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
# Deterministic preprocessing for testing
test_transforms = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

# The problem: both datasets read the whole of data_dir,
# so the "train" and "test" sets contain exactly the same images.
train_image_dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
test_image_dataset = datasets.ImageFolder(data_dir, transform=test_transforms)
train_dataloader = torch.utils.data.DataLoader(train_image_dataset, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_image_dataset, batch_size=32)
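If you would rather split the files on disk before ImageFolder ever sees them, one option is to copy each class's files into separate train/ and test/ trees. The sketch below assumes the usual ImageFolder layout (data_dir/<class_name>/<image>), since that is where ImageFolder takes its labels from. The function name split_image_folder, the output layout, and the default 20% test fraction and seed are illustrative choices, not an existing API.

```python
import os
import random
import shutil


def split_image_folder(src_dir, dst_dir, test_fraction=0.2, seed=42):
    """Copy an ImageFolder-style tree (src_dir/<class>/<file>) into
    dst_dir/train/<class>/ and dst_dir/test/<class>/.

    The split is done per class, so each class keeps roughly the same
    train/test ratio, and every file stays inside its class folder,
    so the label travels with the file. Returns the number of files
    copied into each split.
    """
    rng = random.Random(seed)  # fixed seed -> reproducible split
    counts = {"train": 0, "test": 0}
    for class_name in sorted(os.listdir(src_dir)):
        class_dir = os.path.join(src_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        files = sorted(f for f in os.listdir(class_dir)
                       if os.path.isfile(os.path.join(class_dir, f)))
        rng.shuffle(files)
        n_test = round(len(files) * test_fraction)
        for split, names in (("test", files[:n_test]), ("train", files[n_test:])):
            out_dir = os.path.join(dst_dir, split, class_name)
            os.makedirs(out_dir, exist_ok=True)
            for name in names:
                shutil.copy2(os.path.join(class_dir, name), os.path.join(out_dir, name))
            counts[split] += len(names)
    return counts
```

Afterwards, point one ImageFolder at dst_dir/train with train_transforms and another at dst_dir/test with test_transforms. Copying rather than moving leaves the original folder untouched, at the cost of extra disk space.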
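I think what you are looking for is cross-validation; check this answer. You can add a column with the labels and then apply any cross-validation method to split the data into training and test sets.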
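A minimal plain-Python sketch of that idea, assuming the labels are available as a list in sample order (for an ImageFolder dataset, its targets attribute holds the class index of every sample). The hypothetical helper stratified_kfold_indices builds k stratified folds of sample indices; it is a simplified hand-written stand-in for something like scikit-learn's StratifiedKFold, not that library's implementation.

```python
import random
from collections import defaultdict


def stratified_kfold_indices(labels, k=5, seed=0):
    """Return k (train_indices, test_indices) pairs.

    Each sample index appears in exactly one test fold, and the samples
    of every class are dealt across the k folds as evenly as possible.
    """
    if k < 2:
        raise ValueError("k must be at least 2")
    rng = random.Random(seed)

    # Group sample indices by their label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Shuffle each class, then deal its indices round-robin into the folds,
    # continuing where the previous class stopped so fold sizes stay balanced
    folds = [[] for _ in range(k)]
    offset = 0
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for i, idx in enumerate(indices):
            folds[(offset + i) % k].append(idx)
        offset += len(indices)

    # Fold f is the test set; all other folds form the training set
    splits = []
    for f in range(k):
        test_idx = sorted(folds[f])
        train_idx = sorted(i for g in range(k) if g != f for i in folds[g])
        splits.append((train_idx, test_idx))
    return splits
```

To use a fold with PyTorch, create two ImageFolder instances over the same directory, one with train_transforms and one with test_transforms, and wrap them as torch.utils.data.Subset(train_dataset, train_idx) and torch.utils.data.Subset(test_dataset, test_idx). Both instances list the files in the same sorted order, so the indices refer to the same images and each label stays with its file. Taking only the first (train, test) pair gives a single stratified hold-out split of roughly (k-1)/k versus 1/k.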
To split image files for training and testing, you can use ImageDataGenerator from TensorFlow (Keras):
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 1st: your image directory (one sub-folder per class)
files_dir = "imagefiles/rootdirectory"
image_size = (96, 96)
batch_size = 32

# Create one ImageDataGenerator per subset with the normalization and any other
# transformations you need (augmentation can go on training_gen only); all available
# transformations are listed at the link at the end.
# Both generators must use the same validation_split so the two subsets do not overlap.
# rescale=1./255 maps pixel values from 0..255 into [0, 1].
training_gen = ImageDataGenerator(rescale=1./255, validation_split=0.3)
validation_gen = ImageDataGenerator(rescale=1./255, validation_split=0.3)

# Important: if you are using a multiclass dataset, class_mode should be "categorical".
# subset= selects which part of the validation_split each iterator yields;
# without it, both iterators would return every image in files_dir.
training_set = training_gen.flow_from_directory(files_dir,
                                                target_size=image_size,
                                                batch_size=batch_size,
                                                class_mode="binary",
                                                subset="training")
validation_set = validation_gen.flow_from_directory(files_dir,
                                                    target_size=image_size,
                                                    batch_size=batch_size,
                                                    class_mode="binary",
                                                    subset="validation")
https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator