簡體   English   中英

將圖像數據集拆分為 CNN 的訓練測試拆分

[英]Split dataset of images into train test split for CNN

我正在 kaggle 上訓練一個 CNN,我的數據由兩部分組成:1 個 csv 標簽文件和 1 個圖像文件夾。 如何將 kaggle 上的數據拆分為訓練測試拆分? 謝謝。

在此處輸入圖像描述

這是一個示例圖像:

在此處輸入圖像描述

和相關的標簽(來自 csv):

在此處輸入圖像描述

下面的 function 給出了創建訓練、測試和驗證生成器: source dir - 包含所有圖像的目錄的完整路徑 cvs_path - CSV 文件的路徑,該文件具有包含文件名字符串的列 ( x_col ) 和列 ( y_col ) 包含 class 相關文件名的字符串

note: source_dir/filename results in a path to the file in the source_dir This function automatically determines the batch_size for the generator and steps to us in model.fit so that you go through the train, test, or validation images exactly once per epoch. max_batch_size指定基於 memory 約束允許的最大批量大小 train_split - 在 0 和 1 之間浮動,指定用於訓練的圖像百分比 test_split - 在 0 和 1 之間浮動,指定用於訓練的圖像百分比 注意 validation_split 在內部計算為 1 - train_split - test_split target_size= tuple(height, width) 輸入圖像按比例調整 - 浮點像素被重新調整為像素*比例(通常為 1/255) class_mode - 請參閱 keras flow_from_dataframe 了解詳細信息,通常使用“分類”

import os
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def train_test_valid_split(source_dir, cvs_path,max_batch_size, train_split, test_split, x_col, y_col, class_mode, target_size, scale):
    data=pd.read_csv(cvs_path).copy()
    te_split=test_split/(1-train_split)    
    train_df=data.sample(n=None, frac=train_split, replace=False, weights=None, random_state=123, axis=0)
    tr_batch_size= max_batch_size
    tr_steps=int(len(train_df.index)//tr_batch_size)     
    dummy_df=data.drop(train_df.index, axis=0, inplace=False)     
    test_df=dummy_df.sample(n=None, frac=te_split, replace=False, weights=None, random_state=123, axis=0)
    te_batch_size, te_steps=get_bs(len(test_df.index),max_batch_size )    
    valid_df=dummy_df.drop(test_df.index, axis=0)
    v_batch_size,v_steps=get_bs(len(valid_df.index), max_batch_size)
    gen=ImageDataGenerator(rescale=scale)
    train_gen=gen.flow_from_dataframe(dataframe=train_df, directory=source_dir,batch_size=tr_batch_size, x_col=x_col, y_col=y_col,
                                      target_size=target_size, class_mode=class_mode,seed=123,  validate_filenames=False)    
    test_gen=gen.flow_from_dataframe(dataframe=test_df, directory=source_dir, batch_size=te_batch_size, x_col=x_col, y_col=y_col,
                                     target_size=target_size, class_mode=class_mode,  shuffle=False,validate_filenames=False)
    valid_gen=gen.flow_from_dataframe(dataframe=valid_df, directory=source_dir,batch_size=v_batch_size, x_col=x_col, y_col=y_col, 
                                      target_size=target_size, class_mode=class_mode, shuffle=False,validate_filenames=False)    
    return train_gen, tr_steps, test_gen, te_steps, valid_gen , v_steps

def get_bs(length, b_max):
    batch_size=sorted([int(length/n) for n in range(1,length+1) if length % n ==0 and length/n<=b_max],reverse=True)[0]
    steps=int(length//batch_size)
    return batch_size, steps

CSV 文件的格式為

    file_id     class_id
0   00000.jpg   AFRICAN CROWNED CRANE
1   00001.jpg   AFRICAN CROWNED CRANE
2   00002.jpg   AFRICAN CROWNED CRANE
3   00003.jpg   AFRICAN CROWNED CRANE
4   00004.jpg   AFRICAN CROWNED CRANE
5   00005.jpg   AFRICAN CROWNED CRANE
6   00006.jpg   AFRICAN CROWNED CRANE
7   00007..jpg  AFRICAN CROWNED CRANE
8   00008..jpg  AFRICAN CROWNED CRANE

下面是一個使用示例

source_dir=r'c:\temp\birds\consolidated_images'
cvs_path=r'c:\temp\birds\birds.csv'
train_split=.8
test_split=.1
x_col='file_id'
y_col='class_id'
target_size=(224,224)
scale=1/127.5-1
max_batch_size=32
class_mode='categorical'
train_gen, train_steps, test_gen, test_steps, valid_gen, valid_steps=train_test_valid_split(source_dir,
                cvs_path, max_batch_size, train_split, test_split, x_col, y_col, class_mode, target_size, scale)
print ('train steps: ', train_steps, '  test steps: ', test_steps, '  valid steps: ', valid_steps)

執行結果是

Found 30172 non-validated image filenames belonging to 250 classes.
Found 3772 non-validated image filenames belonging to 250 classes.
Found 3771 non-validated image filenames belonging to 250 classes.
train steps:  942   test steps:  164   valid steps:  419

現在使用這些生成器

epochs= 20 # set to what you want
history=model.fit(x=train_gen, epochs=epochs,steps_per_epoch=train_steps,   
            validation_data=valid_gen, validation_steps=valid_steps,
            shuffle=False,  verbose=1)

訓練結束后

accuracy=model.evaluate(test_gen, steps=test_steps)[1]*100
print ('Model accuracy on test set is', accuracy)

或做預測

predictions=model.predict(test_gen, steps=test_steps, verbose=1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM