繁体   English   中英

Keras - 如何正确使用 fit() 来训练 model?

[英]Keras - How to properly use fit() to train a model?

我一直在熟悉 Keras 并从文档开始,我整理了一个基本的 model 并加载了我自己的图像文件夹来训练而不是使用 mnist 数据集。 我已经到了设置 model 的地步,但我不确定如何从那里开始使用 fit() 方法调用我的数据集,然后训练 model 进行预测。 这是到目前为止的代码:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers.experimental.preprocessing import CenterCrop
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras import layers

#Importing the dataset and setting the path
dataset = keras.preprocessing.image_dataset_from_directory(
    'PetImages',
    batch_size = 64,
    image_size = (200, 200)
)

dataset = keras.Input(shape = (None, None, 3))

# PreProcessing layers to better format the datset 
x = CenterCrop(height=150, width=150)(dataset)
x = Rescaling(scale=1.0 / 255)(x)

# Convolution and Pooling Layers
x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)
x = layers.MaxPooling2D(pool_size=(3, 3))(x)
x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)
x = layers.MaxPooling2D(pool_size=(3, 3))(x)
x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)

# Global average pooling to get flat feature vectors
x = layers.GlobalAveragePooling2D()(x)

# Adding a dense classifier 
num_classes = 10
outputs = layers.Dense(num_classes, activation="softmax")(x)

# Instantiates the model once layers have been set
model = keras.Model(inputs = dataset, outputs = outputs)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

# Problem: Unsure how to further call on dataset to train the model and make a prediction
model.fit()

.fit()方法是实际训练您的网络的方法,以便它以您希望它训练的方式运行。 模型需要数据才能进行训练。 我会查看他们的 文档或他们的一些示例,以了解如何将 go 与您的 model 作为一个良好的起点。

Depending on the version of tensorflow.keras that you are using, .fit can either take two positional arguments x , and y or it can take a generator object, which is something that acts like a continuously active function. 您可能还想设置一个batch_size ,它本质上是一次评估多少个样本。 同样,文档将包含更多关于它可以采用哪种参数的信息。

在您的情况下,您似乎在变量dataset获得了一些好的输入图像(您会立即覆盖),但您没有标签 标签定义了您输入的训练图像的预期 output 是什么。 您需要的第一步是一组标签,然后,您可以对代码进行一些调整以使其运行:

# Add line below
labels = # ... load labels from someplace, like how you loaded the images

# Change this
dataset = keras.Input(shape = (None, None, 3))
# to this
input_layer = layers.Input(shape=(None, None, 3))

# Change this
model = keras.Model(inputs = dataset, outputs = outputs)
# to this
model = keras.models.Model(inputs = input_layer, outputs = outputs)

# and finally you can fit your model using
model.fit(dataset, labels)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM