
Keras model.fit() with tf.dataset API + validation_data

So I have gotten my Keras model to work with a tf.data.Dataset through the following code:

# Initialize batch generators (returns a tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)

# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()

# Create Model
logits = .....(some layers)
model = keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)

However, when I try to pass the validation_data parameter to model.fit, it tells me that I cannot use it with a generator. Is there a way to use validation while using a tf.Dataset?

For example, in plain TensorFlow I could do the following:

# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)

# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)

Then I could just use sess.run(init_op_train) and sess.run(init_op_valid) to switch between the datasets.
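For concreteness, the full loop around those two init ops looks roughly like this (a sketch; train_op, loss_op, and num_valid_batches are hypothetical stand-ins for my own graph and counters):

features, labels = iterator.get_next()
# ... build loss_op and train_op on top of features and labels ...

with tf.Session() as sess:
    for epoch in range(epochs):
        sess.run(init_op_train)    # point the shared iterator at the training set
        for _ in range(num_batches):
            sess.run(train_op)
        sess.run(init_op_valid)    # switch the same iterator to the validation set
        for _ in range(num_valid_batches):
            sess.run(loss_op)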

I tried implementing a callback that does just that (switch to the validation set, predict, and switch back), but it tells me I can't use model.predict in a callback.

Can someone help me get validation working with Keras + tf.Dataset?

Edit: incorporating the answer into the code.

So, FINALLY, what worked for me, thanks to the selected answer, is:

# Initialize batch generators (returns a tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset

# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_valid)

# Create Model
model_inputs = keras.layers.Input(shape=input_shape)  # regular Input layer; batches now come from the generator
logits = # the keras model
model = keras.models.Model(inputs=model_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss)  # no target_tensors: labels come from the generator
model.fit_generator(
    generator=itr_train, validation_data=itr_valid,
    validation_steps=num_valid_batches,  # number of validation batches, not the batch size
    epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)

import keras.backend as K

def make_iterator(dataset):
    # Wrap the tf.Dataset in a one-shot iterator and yield numpy batches,
    # so that Keras' fit_generator can consume it.
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        while True:
            # Evaluate the next batch; the last tensor is the labels,
            # everything before it is the model input(s).
            *inputs, labels = sess.run(next_val)
            yield inputs, labels

This doesn't introduce any overhead.

I solved the problem by using fit_generator. I found the solution here. I applied @Dat-Nguyen's solution.

You simply need to create two iterators, one for training and one for validation, and then create your own generator that extracts batches from the dataset and provides the data in the form of (batch_data, batch_labels). Finally, pass the train_generator and the validation_generator to model.fit_generator.
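For reference, here is a minimal sketch of such a wrapper, assuming the datasets yield (features, labels) pairs; dataset_to_generator is a hypothetical name, and it mirrors the make_iterator shown in the question's edit above:

import keras.backend as K

def dataset_to_generator(dataset):
    # Wrap a tf.Dataset in a plain Python generator for fit_generator
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    sess = K.get_session()
    while True:
        batch_data, batch_labels = sess.run(next_batch)
        yield batch_data, batch_labels

train_generator = dataset_to_generator(batch_train)
validation_generator = dataset_to_generator(batch_valid)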

The way to connect a reinitializable iterator to a Keras model is to plug in an Iterator that returns both the x and y values concurrently:

import numpy as np
import tensorflow as tf
from tensorflow import keras

sess = tf.Session()
keras.backend.set_session(sess)

x = np.random.random((5, 2))
y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2)  # one-hot encoded
input_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(5)  # Keras expects batched data

# Create your reinitializable iterator and its initializer
reinitializable_iterator = tf.data.Iterator.from_structure(input_dataset.output_types,
                                                           input_dataset.output_shapes)
init_op = reinitializable_iterator.make_initializer(input_dataset)

# Run the initializer
sess.run(init_op)  # pass a feed_dict here if you're using placeholders as input

# Build the Keras model and plug in the iterator
model = keras.models.Model(...)
model.compile(...)
model.fit(reinitializable_iterator, ...)

If you also have a validation dataset, the easiest thing to do is to create a separate iterator and plug it into the validation_data parameter. Make sure to define your steps_per_epoch and validation_steps, since they cannot be inferred from an iterator. For example:
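A minimal sketch of that wiring, assuming valid_dataset, num_train_batches, and num_valid_batches are defined on your side, and a tf.keras version that accepts dataset iterators in fit:

valid_iterator = tf.data.Iterator.from_structure(valid_dataset.output_types,
                                                 valid_dataset.output_shapes)
sess.run(valid_iterator.make_initializer(valid_dataset))

model.fit(reinitializable_iterator,
          validation_data=valid_iterator,
          steps_per_epoch=num_train_batches,   # cannot be inferred from an iterator
          validation_steps=num_valid_batches,  # must also be given explicitly
          epochs=10)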
