How do I implement train_test_split in this code?

import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

# Load training and validation sets
ds_train_ = image_dataset_from_directory(
    '../input/car-or-truck/train', 
    labels='inferred',
    label_mode='binary',
    image_size=[128, 128],
    interpolation='nearest',
    batch_size=64,
    shuffle=True,
)
ds_valid_ = image_dataset_from_directory(
    '../input/car-or-truck/valid',
    labels='inferred',
    label_mode='binary',
    image_size=[128, 128],
    interpolation='nearest',
    batch_size=64,
    shuffle=False,
)

print(ds_train_)
print(ds_valid_)
# Data Pipeline
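def convert_to_float(image, label):
    # image_dataset_from_directory already yields float32 images in [0, 255],
    # so this call leaves the values unchanged; use image / 255.0 instead
    # if you want them rescaled to [0, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return image, label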

AUTOTUNE = tf.data.experimental.AUTOTUNE
ds_train = (
    ds_train_
    .map(convert_to_float)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)
ds_valid = (
    ds_valid_
    .map(convert_to_float)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

print(ds_train)
print(ds_valid)

The current code loads a training set and a validation set from two directories that are already separated. However, I want to change it so that it starts from a single directory containing all the images, uses train_test_split to randomly split those images into training and validation sets, and then continues as before. How can I implement train_test_split in this code?

If you are loading the dataset from a directory, use image_dataset_from_directory from TensorFlow together with its validation_split and subset arguments:

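import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

path = "<put path here>"   # a single directory with one subfolder per class
image_size = (128, 128)    # put your image size here (128x128 as in the question)
batch_size = 64

training_data = image_dataset_from_directory(
    path,
    image_size=image_size,
    batch_size=batch_size,
    validation_split=0.2,  # as per your need
    subset="training",
    seed=123,              # required when shuffling; must match the call below
)
validation_data = image_dataset_from_directory(
    path,
    image_size=image_size,
    batch_size=batch_size,
    validation_split=0.2,  # as per your need
    subset="validation",
    seed=123,              # same seed, so the two subsets do not overlap
)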

You can add other arguments, such as labels, label_mode or color_mode, inside the parentheses, but they should be the same in both calls.
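Both calls also need the same validation_split and the same seed. Conceptually, each call lists and shuffles the same files using that seed, then keeps a different slice of the result: the training part or the validation part. With the same seed the two slices are complementary; without a seed they could overlap, which is why Keras raises a ValueError when validation_split is combined with shuffling and no seed is given. The sketch below imitates that idea in plain Python; split_files is a hypothetical helper for illustration, not Keras's actual implementation.

```python
import random


def split_files(paths, validation_split, subset, seed):
    """Hypothetical helper: shuffle a copy of `paths` with a seeded RNG and
    return the training or validation slice, in the spirit of
    image_dataset_from_directory(validation_split=..., subset=..., seed=...)."""
    shuffled = sorted(paths)                # deterministic starting order
    random.Random(seed).shuffle(shuffled)   # same seed -> same permutation
    num_val = int(validation_split * len(shuffled))
    if subset == "training":
        return shuffled[: len(shuffled) - num_val]
    if subset == "validation":
        return shuffled[len(shuffled) - num_val :]
    raise ValueError("subset must be 'training' or 'validation'")


files = [f"img_{i:03d}.jpg" for i in range(100)]
train = split_files(files, 0.2, "training", seed=123)
valid = split_files(files, 0.2, "validation", seed=123)
print(len(train), len(valid), set(train) & set(valid))  # 80 20 set()
```

If the second call used a different seed, the permutation would differ and some files could land in both slices.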

If you have any more questions, refer to the image_dataset_from_directory documentation.
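If you specifically want scikit-learn's train_test_split, as the question asks (for example to stratify by class), one option is to split the list of file paths and labels rather than the images, and build the datasets afterwards. The sketch below assumes the single directory has one subfolder per class (the layout labels='inferred' expects, with classes in alphanumeric order) and that scikit-learn is installed; split_paths is a hypothetical helper name.

```python
from pathlib import Path

from sklearn.model_selection import train_test_split

# File extensions that image_dataset_from_directory also accepts.
IMAGE_EXTENSIONS = {".bmp", ".gif", ".jpeg", ".jpg", ".png"}


def split_paths(directory, test_size=0.2, random_state=123):
    """Hypothetical helper: collect image paths from a one-subfolder-per-class
    layout, label them by sorted class name, then split with train_test_split."""
    root = Path(directory)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    paths, labels = [], []
    for index, name in enumerate(class_names):
        for file in sorted((root / name).rglob("*")):
            if file.suffix.lower() in IMAGE_EXTENSIONS:
                paths.append(str(file))
                labels.append(index)
    # stratify=labels keeps the class ratio the same in both subsets;
    # random_state makes the split reproducible.
    train_paths, valid_paths, train_labels, valid_labels = train_test_split(
        paths, labels, test_size=test_size, random_state=random_state, stratify=labels
    )
    return (train_paths, train_labels), (valid_paths, valid_labels), class_names
```

The resulting lists could then be turned into tf.data pipelines, for instance with tf.data.Dataset.from_tensor_slices((train_paths, train_labels)) and a map that calls tf.io.read_file, tf.io.decode_jpeg and tf.image.resize, followed by .batch(64) and the same convert_to_float / cache / prefetch chain as in the question.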

If this answers your question, please upvote; otherwise, let me know in the comments.