
How do I modify train_step in order to support a validation set when called in model.fit()?

I'm following this Keras tutorial, which explains how to write a custom train_step() function while still being able to call model.fit() to train the model:

https://keras.io/guides/customizing_what_happens_in_fit/
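A minimal sketch of the pattern that guide describes: override train_step() and keep training with fit(). It assumes TensorFlow 2.x with its bundled Keras. The class name CustomModel, the single Dense layer, the hand-written mean-squared-error loss and the loss_tracker attribute are illustrative assumptions, not part of the question. The loss is tracked with keras.metrics.Mean, and listing that tracker in the metrics property lets fit() reset it at the start of each epoch, so no loss= argument is passed to compile().

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    """Hypothetical example model: one Dense layer with a hand-written train_step()."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)
        # Running mean of the custom loss over the current epoch.
        self.loss_tracker = keras.metrics.Mean(name="loss")

    def call(self, x, training=False):
        return self.dense(x)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Illustrative custom loss: mean squared error computed by hand.
            loss = tf.reduce_mean(tf.square(y - y_pred))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    @property
    def metrics(self):
        # fit() resets every metric listed here at the start of each epoch.
        return [self.loss_tracker]


rng = np.random.default_rng(0)
x = rng.random((64, 4), dtype=np.float32)
y = rng.random((64, 1), dtype=np.float32)

model = CustomModel()
model.compile(optimizer="adam")
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
print(history.history["loss"])
```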

model.fit() is supposed to support validation_data, but I can't figure out where to put the code that computes custom metrics and custom losses for validation_data. For now I've fallen back to writing a custom training loop, but I would prefer to use fit().

Any ideas?

I completely missed the paragraph of the guide that mentions the test_step() function:

from tensorflow import keras


# test_step() is a method of the keras.Model subclass used in the guide.
class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Update the metrics tracking the loss
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
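fit() evaluates validation_data with test_step() after each training epoch (with the default validation_freq=1) and prefixes the returned keys with val_, so custom validation losses and metrics belong in test_step(). The sketch below pairs a train_step() and a test_step() that share hand-written trackers instead of compiled_loss and compiled_metrics. It assumes TensorFlow 2.x with its bundled Keras; the model, the hypothetical _custom_loss helper (plain MSE here) and the MAE metric are illustrative choices only.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    """Hypothetical model with custom train_step() and test_step()."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)
        # Shared trackers: fit() resets them before each epoch and evaluate()
        # resets them before the validation pass, so the two don't mix.
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    def call(self, x, training=False):
        return self.dense(x)

    def _custom_loss(self, y, y_pred):
        # Hypothetical custom loss: mean squared error.
        return tf.reduce_mean(tf.square(y - y_pred))

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self._custom_loss(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        # Runs on validation_data during fit() (and in evaluate()); no gradients.
        x, y = data
        y_pred = self(x, training=False)
        loss = self._custom_loss(y, y_pred)
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    @property
    def metrics(self):
        return [self.loss_tracker, self.mae_metric]


rng = np.random.default_rng(0)
x_train = rng.random((64, 4), dtype=np.float32)
y_train = rng.random((64, 1), dtype=np.float32)
x_val = rng.random((32, 4), dtype=np.float32)
y_val = rng.random((32, 1), dtype=np.float32)

model = CustomModel()
model.compile(optimizer="adam")
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=2,
    batch_size=16,
    verbose=0,
)
print(sorted(history.history))
```

With this setup history.history should hold loss, mae, val_loss and val_mae, one value per epoch each.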
