How to split data into train and test sets from one directory with PyTorch?
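I have a single data folder that is not split into training and test folders. How can I split the data into a training set and a test set? The labels come from the file names, so any change to the ordering has to carry the labels along with it. I want to split the data before using ImageFolder so that different transforms can be applied to the training and test datasets.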
import torch
from torchvision import datasets, transforms

# Random augmentation for training
train_transforms = transforms.Compose([transforms.RandomRotation(10),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
# Deterministic resize and crop for testing
test_transforms = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])
# data_dir is the single folder holding all images. Both datasets below
# currently cover every file in it, which is the problem being asked about.
train_image_dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
test_image_dataset = datasets.ImageFolder(data_dir, transform=test_transforms)
train_dataloader = torch.utils.data.DataLoader(train_image_dataset, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_image_dataset, batch_size=32)
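I think what you are looking for is cross-validation; see this answer. You can add a column holding the labels and then apply any cross-validation method to split the data into test and train sets.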
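For a single hold-out split in PyTorch itself, one option is to build two ImageFolder objects over the same directory, one per transform, and give each a disjoint set of indices through torch.utils.data.Subset. The sketch below assumes that approach. split_indices and build_split_datasets are hypothetical helper names, and the 20% test fraction and the seed are arbitrary defaults. Because each index selects an (image, label) pair, the labels stay attached to their images however the indices are shuffled.

```python
import random


def split_indices(n_samples, test_fraction=0.2, seed=0):
    """Return two disjoint, sorted index lists (train, test) covering range(n_samples)."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be strictly between 0 and 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_test = int(round(n_samples * test_fraction))
    return sorted(indices[n_test:]), sorted(indices[:n_test])


def build_split_datasets(data_dir, train_transforms, test_transforms,
                         test_fraction=0.2, seed=0):
    """Wrap one folder in two ImageFolder views and give each a disjoint subset."""
    # Imported here so split_indices can be used without PyTorch installed.
    from torch.utils.data import Subset
    from torchvision import datasets

    # Both objects list the same files in the same sorted order;
    # only the transform applied when a sample is loaded differs.
    train_full = datasets.ImageFolder(data_dir, transform=train_transforms)
    test_full = datasets.ImageFolder(data_dir, transform=test_transforms)

    train_idx, test_idx = split_indices(len(train_full), test_fraction, seed)
    return Subset(train_full, train_idx), Subset(test_full, test_idx)
```

Usage would look like train_set, test_set = build_split_datasets(data_dir, train_transforms, test_transforms), after which each subset can be passed to its own DataLoader. Note that this is one fixed split, not k-fold cross-validation.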
To split the image files for training and testing, you can use ImageDataGenerator from TensorFlow:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# First, point to your image root directory (one subfolder per class)
files_dir = "imagefiles/rootdirectory"
image_size = (96, 96)
batch_size = 32

# Create the image data generator objects with the normalization and any other
# transformations you need; the link at the end lists all available options.
# Both generators use the same validation_split so that they agree on which
# files belong to which subset.
training_gen = ImageDataGenerator(rescale=1 / 256, validation_split=0.3)
validation_gen = ImageDataGenerator(rescale=1 / 256, validation_split=0.3)

# Important: for a multiclass dataset, class_mode should be "categorical".
# subset selects which side of the split each iterator reads.
training_set = training_gen.flow_from_directory(files_dir,
                                                batch_size=batch_size,
                                                target_size=image_size,
                                                class_mode="binary",
                                                subset="training")
validation_set = validation_gen.flow_from_directory(files_dir,
                                                    target_size=image_size,
                                                    batch_size=batch_size,
                                                    class_mode="binary",
                                                    subset="validation")
https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator