
TensorFlow Keras: tf.keras.Model train_on_batch vs make_train_function - Why is one slower than the other?

In the TF2.2 release candidate, another way of training is to generate a training function with training_function = tf.keras.Model.make_train_function(), which performs one training step each time it is called:

training_function(data)

Another way of training is to use tf.keras.Model.train_on_batch(data). However, I find a performance difference between the two: train_on_batch takes around 25% longer.
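A plain-Python toy (no TensorFlow involved) shows where such a gap can come from. If building the step function is expensive, a path that rebuilds it on every call pays that cost on every step, while a function built once and reused pays it only once. ToyModel, build_cost_s and the sleep-based build cost are hypothetical stand-ins, not Keras internals.

```python
import time


class ToyModel:
    """Hypothetical stand-in for a Keras model; not the real tf.keras API."""

    def __init__(self, build_cost_s=0.002):
        self.build_cost_s = build_cost_s  # simulated cost of building a step function
        self.weight = 0.0
        self.builds = 0

    def make_train_function(self):
        # Simulate the one-off cost of building (e.g. tracing) a step function.
        time.sleep(self.build_cost_s)
        self.builds += 1

        def train_step(x, y, lr=0.1):
            # One gradient-descent step on the loss (w*x - y)^2.
            grad = 2 * (self.weight * x - y) * x
            self.weight -= lr * grad
            return (self.weight * x - y) ** 2

        return train_step

    def train_on_batch_rebuilding(self, x, y):
        # Pathological path: rebuild the step function on every call.
        return self.make_train_function()(x, y)


def time_steps(step, n_steps):
    start = time.perf_counter()
    for _ in range(n_steps):
        step(1.0, 2.0)
    return time.perf_counter() - start


model = ToyModel()
train_fn = model.make_train_function()   # build once...
cached_s = time_steps(train_fn, 20)      # ...then reuse it
rebuild_s = time_steps(model.train_on_batch_rebuilding, 20)
print(f"cached: {cached_s:.4f}s, rebuilding: {rebuild_s:.4f}s, builds: {model.builds}")
```

The absolute numbers depend only on the simulated build cost; the point is that per-call rebuilding grows with the number of steps.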

Is there any reason why training with training_function = tf.keras.Model.make_train_function(); training_function(data) would be faster than tf.keras.Model.train_on_batch()?

(Some other details: I installed TF using conda, so I'm actually using TF2.1, and I've emulated TF2.2's make_train_function() with the private _make_train_function() (note the leading underscore) that exists in TF2.1:

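from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras import backend as K

model = ...  # Some compiled Keras model
data = ...   # Some TF dataset

# First call (results discarded) lets Keras set up the model's inputs before the train function is built
_, _, _ = model._standardize_user_data(data, None)
training_function = model._make_train_function()

# Flatten inputs, targets and sample weights into the single list the function expects
data, data_y_None, data_sampleweights_None = model._standardize_user_data(data, None)
data = training_utils.ModelInputs(data).as_list()
data = data + list(data_y_None or []) + list(data_sampleweights_None or [])
# If the learning phase is symbolic, feed True (training mode) as the last input
if not isinstance(K.symbolic_learning_phase(), int):
    data += [True]
training_function(data)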

I'm not entirely sure why this method of training is so much faster, but it does train correctly. Any help would be appreciated :) I'm hoping there is something really obvious that I'm missing.)
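The list-building at the end of the TF2.1 snippet flattens everything into the single list the backend function takes: model inputs, then targets, then sample weights, then (when K.symbolic_learning_phase() is not a plain int) a training flag. The hypothetical helper below restates that ordering in plain Python so the `or []` handling of `None` is easy to see; `assemble_feed` is not a Keras function.

```python
def assemble_feed(inputs, targets=None, sample_weights=None, needs_learning_phase=True):
    """Hypothetical helper mirroring the list-building in the TF2.1 snippet.

    Order: inputs, then targets, then sample weights, then an optional
    learning-phase flag (True means "training").
    """
    feed = list(inputs)
    # `x or []` turns a missing (None) group into an empty list.
    feed += list(targets or [])
    feed += list(sample_weights or [])
    if needs_learning_phase:
        feed.append(True)
    return feed


print(assemble_feed(["x_batch"], None, None, needs_learning_phase=True))
print(assemble_feed(["x_batch"], ["y_batch"], ["w_batch"], needs_learning_phase=False))
```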

UPDATE: A quick inspection of the code for both TF2.1 and TF2.2-rc reveals that train_on_batch calls (_)make_train_function every time it is called, which accounts for the extra 25% of time. So the question is now: why does train_on_batch recreate the training function each time?

UPDATE: The training function is only created once, since it is then kept as an object property. However, it is re-created (and the object property overwritten) if the model detects that it has been recompiled since the last call. For some reason this recompile is being triggered in my code, so the training function is re-created on every step, and I'm not sure why. Without a full diagnostic example it is hard for you to help, but I'm wondering whether it has to do with my using .add_metric() and .add_loss() in a custom tf.keras.Model().
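One way to confirm whether the function really is being rebuilt is to watch the cached attribute between steps, keeping a reference to the previous object and comparing with `is`. The sketch assumes the attribute is named `train_function`, as in the TF1.15 if statement quoted in the answer; `count_rebuilds` and `RebuildingModel` are hypothetical, and the toy's `dirty_every_step` flag merely simulates "something re-triggers the recompile check", it does not claim that add_metric()/add_loss() do so. With a real model you would pass something like `lambda: model.train_on_batch(x, y)` as `run_step`.

```python
def count_rebuilds(model, run_step, n_steps, attr="train_function"):
    """Count how often `model.<attr>` is replaced across `n_steps` calls.

    Holding a reference to the previous object and comparing with `is`
    avoids false matches from recycled id() values.
    """
    previous = getattr(model, attr, None)
    rebuilds = 0
    for _ in range(n_steps):
        run_step()
        current = getattr(model, attr, None)
        if current is not previous:
            rebuilds += 1
        previous = current
    return rebuilds


class RebuildingModel:
    """Hypothetical model: rebuilds its step function whenever `dirty` is set."""

    def __init__(self, dirty_every_step=False):
        self.dirty_every_step = dirty_every_step
        self.dirty = True
        self.train_function = None

    def train_on_batch(self):
        if self.train_function is None or self.dirty:
            self.train_function = lambda: None  # stand-in for a built step function
            self.dirty = False
        self.train_function()
        if self.dirty_every_step:
            self.dirty = True  # simulates the "recompiled" check firing again


stable = RebuildingModel()
flaky = RebuildingModel(dirty_every_step=True)
print(count_rebuilds(stable, stable.train_on_batch, 10))  # 1 (only the first build)
print(count_rebuilds(flaky, flaky.train_on_batch, 10))    # 10 (rebuilt every step)
```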

train_on_batch does call _make_execution_function() each time in tensorflow1.15, but the execution function is created only once (when it does not exist yet, or when the model has been recompiled) because of this if statement:

if getattr(self, 'train_function', None) is None or has_recompiled:
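A minimal emulation of that guard, assuming `has_recompiled` can be approximated by comparing a compile counter with the value recorded at build time (an assumption; Keras computes it differently). `CachingModel` is hypothetical: its `_make_train_function()` is called on every `train_on_batch()`, but the body only does work when the cached function is missing or stale.

```python
class CachingModel:
    """Hypothetical model emulating the caching guard quoted above."""

    def __init__(self):
        self._compile_count = 0
        self._compile_count_at_build = None
        self.builds = 0

    def compile(self):
        self._compile_count += 1

    def _make_train_function(self):
        # Assumed stand-in for Keras' recompile detection.
        has_recompiled = self._compile_count != self._compile_count_at_build
        # Same shape as the quoted check: build only if missing or stale.
        if getattr(self, 'train_function', None) is None or has_recompiled:
            self.builds += 1
            self._compile_count_at_build = self._compile_count
            self.train_function = lambda batch: sum(batch)  # stand-in step

    def train_on_batch(self, batch):
        self._make_train_function()        # called every time...
        return self.train_function(batch)  # ...but usually returns without rebuilding


model = CachingModel()
model.compile()
for _ in range(5):
    model.train_on_batch([1, 2, 3])
print(model.builds)  # 1
model.compile()      # recompiling makes the cached function stale
model.train_on_batch([1, 2, 3])
print(model.builds)  # 2
```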

In any case, _make_train/test/predict_function() are private functions meant to support the internal implementation, and, as you have correctly noticed, they are not present in tensorflow2.2.

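In tensorflow2.2 the public counterpart is make_train_function() (alongside make_test_function() and make_predict_function()), and it does its full work only once: if self.train_function is already set, it returns that cached function immediately.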
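A minimal sketch of the TF2.2-style early-return cache used by make_train_function(), assuming (as a simplification) that compile() just resets the cached function to None. `TF22StyleModel` and its stand-in step are hypothetical, not real Keras code.

```python
class TF22StyleModel:
    """Hypothetical model mirroring a TF2.2-style cache; not real Keras code."""

    def __init__(self):
        self.train_function = None
        self.builds = 0

    def compile(self):
        # Assumption: recompiling clears the cached step function.
        self.train_function = None

    def make_train_function(self):
        if self.train_function is not None:
            return self.train_function  # early return: nothing is rebuilt
        self.builds += 1
        self.train_function = lambda batch: len(batch)  # stand-in step
        return self.train_function


model = TF22StyleModel()
model.compile()
f1 = model.make_train_function()
f2 = model.make_train_function()
print(f1 is f2, model.builds)  # True 1
model.compile()
print(model.make_train_function() is f1, model.builds)  # False 2
```

Compared with the TF1.15/TF2.1 guard, the staleness check moves to compile time, so the per-step path is a single `is not None` test.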
