简体   繁体   English

无法使用 tf.function 训练 model

[英]fail to train a model using tf.function

I tried to use tf.function to decorate the gradient update function as below.我尝试使用tf.function来装饰渐变更新 function 如下。

import tensorflow as tf
from tensorflow.keras import layers, activations, losses
import numpy as np
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import Progbar

# generate data
nb_doc = 100
doc_features = np.random.random((nb_doc, 10))
doc_scores = np.random.randint(2, size=nb_doc).astype(np.float32)

class simple_model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = [layers.Dense(16, activation=tf.nn.leaky_relu), layers.Dense(8, activation=tf.nn.leaky_relu)]
        self.score = layers.Dense(1, activation='sigmoid')
    
    def call(self, inputs):
        dense_a = self.dense[0](inputs)
        for dense in self.dense[1:]:
            dense_a = dense(dense_a)
        y = self.score(dense_a)
        return y
    
    def build_graph(self):
        x = tf.keras.Input(shape=(10))
        return tf.keras.Model(inputs=x, outputs=self.call(x))
        
batch_size = 1
train = tf.data.Dataset.from_tensor_slices((doc_features, doc_scores)).shuffle(nb_doc).batch(batch_size)

loss_object = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

@tf.function
def apply_gradient(optimizer, model, x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss_value = loss_object(y, y_pred)
    
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    
    return y_pred, loss_value


def train_data_for_one_epoch(optimizer, model):
    losses = []
    pb_i = Progbar(np.ceil(nb_doc // batch_size), stateful_metrics=['loss'])
    for step, (x, y) in enumerate(train):
        y_pred, loss_value = apply_gradient(optimizer, model, x, y)
        losses.append(loss_value)

        pb_i.update(step+1, values=[('loss', loss_value)], finalize=False)
    pb_i.update(step+1, values=[('loss', np.mean(losses))], finalize=True)
    return losses

I can train a model successfully the first time I call the code below.第一次调用下面的代码时,我可以成功训练 model。
But it fail when I train to train another model using the same code , with error message ValueError: tf.function-decorated function tried to create variables on non-first call.但是当我使用相同的代码训练另一个 model 时它失败了,错误消息ValueError: tf.function-decorated function tried to create variables on non-first call.

# this succeed
epochs = 5
_model = simple_model()
loss_history = []
for epoch in range(epochs):
    print('Epoch %d/%d'%(epoch+1, epochs))
    losses_train = train_data_for_one_epoch(optimizer, _model)
    loss_history.append(np.mean(losses_train))

# this fail
epochs = 5
_model_2 = simple_model()
loss_history = []
for epoch in range(epochs):
    print('Epoch %d/%d'%(epoch+1, epochs))
    losses_train = train_data_for_one_epoch(optimizer, _model_2)
    loss_history.append(np.mean(losses_train))

This seems to be a known issue as indicated here这似乎是一个已知问题,如此处所示

And the work around is解决方法是

# removed @tf.function decorator
def apply_gradient(optimizer, loss_object, model, x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss_value = loss_object(y, y_pred)
    
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    
    return y_pred, loss_value


def train_data_for_one_epoch(optimizer, loss_object, model):
    losses = []
    
    # added tf.function here
    apply_grads =tf.function(apply_gradient)
    
    pb_i = Progbar(np.ceil(nb_doc // batch_size), stateful_metrics=['loss'])
    for step, (x, y) in enumerate(train):
        y_pred, loss_value = apply_grads(optimizer, loss_object, model, x, y)
        losses.append(loss_value)

        pb_i.update(step+1, values=[('loss', loss_value)], finalize=False)
    pb_i.update(step+1, values=[('loss', np.mean(losses))], finalize=True)
    return losses

Then the two models could train without error.然后这两个模型可以无错误地训练。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM