
Get a Keras model to output a result and another using the moving average of the weights

Given two Keras models model1 and model2 with identical architectures, I need to train the first with regular weight updates and the second with the moving average of the weights. Here's an example to illustrate:

from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.optimizers import MovingAverage
import tensorflow as tf


model1 = Model(...)
model2 = tf.keras.models.clone_model(model1)
opt1 = Adam()
opt2 = MovingAverage(Adam())
model1.compile(optimizer=opt1)
model2.compile(optimizer=opt2)
with tf.GradientTape() as tape, tf.GradientTape() as tape2:
    # the loss function is the same for both, but each loss must be
    # computed from that model's forward pass inside the tape so that
    # gradients can flow back to its variables
    loss1 = calculate_loss(model1(inp))
    loss2 = calculate_loss(model2(inp))
grads1 = tape.gradient(loss1, model1.trainable_variables)
grads2 = tape2.gradient(loss2, model2.trainable_variables)
model1.optimizer.apply_gradients(zip(grads1, model1.trainable_variables))
model2.optimizer.apply_gradients(zip(grads2, model2.trainable_variables))

After each gradient update, both models will be called on the same input to output separate values.

v1 = model1(inp)
v2 = model2(inp)

Is it possible to get rid of the duplicated logic (tape and tape2, grads1 and grads2, ...) by merging both models into a single model that outputs both results: one from the raw weights and one from the averaged weights?

Basically, you could create two copies of the same network under one model, but under different name scopes, and then at optimization time use one optimizer to update the regular weights and another (MovingAverage-wrapped) optimizer to update only the moving-average weights.

Data

import numpy as np
import tensorflow as tf

from tensorflow_addons.optimizers import MovingAverage
from tensorflow.keras.optimizers import Adam


# fake data
X = tf.random.normal([1000, 128])
y = tf.one_hot(
    tf.random.uniform([1000], minval=0, maxval=3, dtype=tf.int64),
    depth=3)

Custom Model

# custom model with weights under specific name scopes
class DualWeightModel(tf.keras.Model):
    def __init__(self, num_units=256):
        super().__init__()
        self.num_units = num_units
        self.x_r = tf.keras.layers.Dense(self.num_units)
        self.l_r = tf.keras.layers.Dense(3, activation="softmax")
        self.x_ma = tf.keras.layers.Dense(self.num_units)
        self.l_ma = tf.keras.layers.Dense(3, activation="softmax")
      
    def call(self, x):
        with tf.name_scope("regular"):
            out_r = self.l_r(self.x_r(x))
        with tf.name_scope("ma"):
            out_ma = self.l_ma(self.x_ma(x))
        return out_r, out_ma
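
As a quick sanity check that the two heads are wired up, a single forward pass returns one output per head (model_check here is just a throwaway instance for illustration):

# sanity check: one forward pass yields both heads' outputs
model_check = DualWeightModel()
out_r, out_ma = model_check(tf.random.normal([4, 128]))
print(out_r.shape, out_ma.shape)  # (4, 3) (4, 3)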
  
  
# loss function
def calc_loss(y_true, y_pred):
    return tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)

Optimization

# optimizers
opt_r = Adam(1e-4)
opt_ma = MovingAverage(Adam(1e-4))
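
If you want the averages to track the raw weights more slowly, tensorflow_addons' MovingAverage also accepts an average_decay argument (it defaults to 0.99), e.g.:

# optional variant: slower-tracking averages via average_decay
opt_ma_slow = MovingAverage(Adam(1e-4), average_decay=0.999)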


# instantiate model
model = DualWeightModel()


# define one train step
def train_step(X, y):
    # forward pass
    with tf.GradientTape(persistent=True) as tape:
        y_hat_r, y_hat_ma = model(X)
        r_loss = calc_loss(y, y_hat_r)
        ma_loss = calc_loss(y, y_hat_ma)
        
    # get trainable variables under each name scope
    r_vars = []
    ma_vars = []
    
    for v in model.trainable_variables:
        if 'regular' in v.name:
            r_vars.append(v)
        if 'ma' in v.name:
            ma_vars.append(v)
    
    # optimize
    r_grads = tape.gradient(r_loss, r_vars)
    ma_grads = tape.gradient(ma_loss, ma_vars)

    opt_r.apply_gradients(zip(r_grads, r_vars))
    opt_ma.apply_gradients(zip(ma_grads, ma_vars))
    
    return r_loss, ma_loss
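
If matching substrings of v.name ever feels fragile (e.g. 'ma' could in principle appear in an unrelated variable name), an alternative, sketched here using the same layer attributes as the class above, is to collect the variables straight from the layer objects:

# alternative: partition variables by layer attribute instead of name matching
# (the layers must already be built, i.e. called at least once)
r_vars = model.x_r.trainable_variables + model.l_r.trainable_variables
ma_vars = model.x_ma.trainable_variables + model.l_ma.trainable_variables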

Train Model

# train (note .repeat(): 10 epochs x 100 steps = 1000 next() calls, but
# 1000 samples in batches of 32 yield only ~31 batches per pass)
train_iter = iter(
    tf.data.Dataset.from_tensor_slices((X, y)).repeat().batch(32))

for epoch in range(10):
    r_losses, ma_losses = [], []
    for batch in range(100):
        X_train, y_train = next(train_iter)
        r_loss, ma_loss = train_step(X_train, y_train)
        r_losses.append(r_loss)
        ma_losses.append(ma_loss)
        
        if batch % 5 == 0:
            msg = (f"r_loss: {np.mean(r_losses):.4f} "
                   f"\tma_loss: {np.mean(ma_losses):.4f}")
            print(msg)
            r_losses = []
            ma_losses = []

# r_loss: 1.6749    ma_loss: 1.7274
# r_loss: 1.4319    ma_loss: 1.6590
# ...
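
One caveat: MovingAverage keeps its averages as shadow copies in optimizer slots, so the ma branch above is still evaluated with its raw, gradient-updated weights. To actually read predictions off the averaged weights, you can copy the averages into the variables with assign_average_vars (part of tfa's MovingAverage API) before inference, for example:

# swap the shadow averages into the ma-branch variables before inference
ma_vars = model.x_ma.trainable_variables + model.l_ma.trainable_variables
opt_ma.assign_average_vars(ma_vars)

v_r, v_ma = model(X[:5])  # v_ma now comes from the averaged weights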
