Simple TensorFlow (Keras) model using tf.tensor objects of structured data

I have two tensor objects, train and labels (train_X and train_Y in the code below). The dataset train has 100 features, and labels has 1 feature. Both train and labels have M entries. Similarly, we have a dev and dev_labels set (dev_X and dev_Y) with the same respective number of features and N entries. After importing Keras from TensorFlow, we create a neural network as follows:

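import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Flatten(input_shape=[100]),           # each row has 100 features
    keras.layers.Dense(100, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)   # 10 output classes
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # expects integer class labels
              metrics=['accuracy'])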

Now we want to fit the model with batches of size P for Q epochs:

model.fit(train_X, train_Y, validation_data=(dev_X, dev_Y), epochs=Q, steps_per_epoch=??, validation_steps=??)

After reading the documentation on model.fit, I am still not sure what the correct steps_per_epoch and validation_steps would be here. When using data tensors as input to a model, these parameters must be specified. In this example, what should we specify for steps_per_epoch and validation_steps?

steps_per_epoch should be roughly equal to the number of training examples divided by the batch size (the default batch size is 32). Similarly, validation_steps should be roughly equal to the number of validation examples divided by the batch size. From the model.fit documentation:
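A minimal sketch of that arithmetic, assuming plain Python integers for the example count and batch size. The helper names `steps_for` and `skipped_if_floored` are hypothetical. It contrasts rounding down, which leaves the last partial batch out of every epoch, with rounding up, which includes it as one extra, smaller step.

```python
def steps_for(num_examples: int, batch_size: int = 32) -> int:
    """Hypothetical helper: batches needed to see every example once."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Integer ceiling division: a final, smaller batch still counts as a step.
    return -(-num_examples // batch_size)


def skipped_if_floored(num_examples: int, batch_size: int = 32) -> int:
    """Examples left out of each epoch if the step count is rounded down."""
    return num_examples - (num_examples // batch_size) * batch_size


if __name__ == "__main__":
    print(steps_for(1000))           # 32 steps: 31 full batches + 1 batch of 8
    print(1000 // 32)                # 31 if you round down instead
    print(skipped_if_floored(1000))  # 8 examples never seen per epoch
```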

steps_per_epoch: Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. When training with input tensors such as TensorFlow data tensors, the default None is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined.

validation_steps: Only relevant if steps_per_epoch is specified. Total number of steps (batches of samples) to validate before stopping.
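To make "steps (batches of samples)" concrete, the sketch below slices a dataset into consecutive batches, one step per batch, and counts how many steps a full pass takes. It is a plain-Python illustration, not Keras's internal code: a list stands in for the tensor, and `iter_batches` and `count_steps` are hypothetical names.

```python
from typing import Iterator, List, Sequence


def iter_batches(data: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive batches; the last one may be smaller than batch_size."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def count_steps(data: Sequence, batch_size: int) -> int:
    """Number of steps needed for one full pass over data."""
    return sum(1 for _ in iter_batches(data, batch_size))


if __name__ == "__main__":
    dev = list(range(200))  # stand-in for N = 200 validation rows
    sizes: List[int] = [len(b) for b in iter_batches(dev, 64)]
    print(sizes)                 # [64, 64, 64, 8]
    print(count_steps(dev, 64))  # 4 validation steps
```

Passing a smaller step count than this means part of the data is not visited in that pass; passing a larger one with a finite input may make Keras stop early with an "input ran out of data" warning.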

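In your case, with batch_size = P, they should be rounded up to whole batches:

import math
# Steps must be integers; rounding up keeps the last, partial batch
steps_per_epoch = math.ceil(len(train_X) / batch_size)
validation_steps = math.ceil(len(dev_X) / batch_size)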
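Putting it together for the question's symbols (M training rows, N validation rows, batch size P, Q epochs), here is a sketch that builds the step-related keyword arguments for model.fit. The function name `fit_step_kwargs` and the example numbers are hypothetical; the returned dict would be unpacked into the call from the question, e.g. `model.fit(train_X, train_Y, validation_data=(dev_X, dev_Y), **kwargs)`.

```python
def fit_step_kwargs(m: int, n: int, p: int, q: int) -> dict:
    """Hypothetical helper: step arguments for model.fit from M, N, P, Q."""
    if p <= 0:
        raise ValueError("batch size P must be positive")
    return {
        "epochs": q,
        "steps_per_epoch": -(-m // p),   # ceil(M / P) training batches
        "validation_steps": -(-n // p),  # ceil(N / P) validation batches
    }


if __name__ == "__main__":
    kwargs = fit_step_kwargs(m=1000, n=200, p=64, q=5)
    print(kwargs)  # {'epochs': 5, 'steps_per_epoch': 16, 'validation_steps': 4}
```

Integer ceiling division (`-(-a // b)`) is used instead of `math.ceil(a / b)` so the result stays exact even for very large counts.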
