简体   繁体   English

如何在keras子类化模型中使用回调?

[英]How to use callbacks in keras subclassed model?

I instantiate keras subclassed model, as tensorflow guide , 我实例化了keras子模型,作为tensorflow指南

To be able to use .fit to my model class, I need to create compute_output_shape . 为了能够对我的模型类使用.fit ,我需要创建compute_output_shape Nevertheless, using callbacks for fitted model throw NotImplementedError . 但是,对拟合模型使用回调会抛出NotImplementedError

So, what can I do to use callbacks in keras subclassed model, such as tensorboard, checkpoints, etc ? 那么,我该怎么做才能在keras子类化模型中使用回调,例如tensorboard,检查点等?

You can try this: 您可以尝试以下方法:

model = SubclassModel()
# Callbacks you can define
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='model')
# Add callbacks to fit parameter
model.fit(data, labels, batch_size=100, epochs=5, callbacks=[tensorboard_callback])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM