[英]Nested Gradient Tape in function (TF2.0)
我嘗試實現 MAML。 因此,我需要我的 model (model_copy) 的副本進行一步訓練,然后我需要在丟失 model_copy 的情況下訓練我的 meta_model。
我想在 function 中訓練 model_copy。 如果我將我的代碼復制到 function 我沒有得到正確的 gradients_meta (它們都沒有)。
看來,這些圖是未連接的 - 我如何連接這些圖?
知道我做錯了什么嗎? 我觀察了很多變量,但這似乎並沒有什么不同。
這是重現此問題的代碼:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend
def copy_model(model):
copied_model = keras.Sequential()
copied_model.add(keras.layers.Dense(5, input_shape=(1,)))
copied_model.add(keras.layers.Dense(1))
copied_model.set_weights(model.get_weights())
return copied_model
def compute_loss(model, x, y):
logits = model(x) # prediction of my model
mse = keras_backend.mean(keras.losses.mean_squared_error(y, logits)) # compute loss between prediciton and label/truth
return mse, logits
# meta_model to learn in outer gradient tape
meta_model = keras.Sequential()
meta_model.add(keras.layers.Dense(5, input_shape=(1,)))
meta_model.add(keras.layers.Dense(1))
# optimizer for training
optimizer = keras.optimizers.Adam()
# function to calculate model_copys params
def do_calc(x, y, meta_model):
with tf.GradientTape() as gg:
model_copy = copy_model(meta_model)
gg.watch(x)
gg.watch(meta_model.trainable_variables)
gg.watch(model_copy.trainable_variables)
loss, _ = compute_loss(model_copy, x, y)
gradient = gg.gradient(loss, model_copy.trainable_variables)
optimizer.apply_gradients(zip(gradient, model_copy.trainable_variables))
return model_copy
# inputs for training
x = tf.constant(3.0, shape=(1, 1, 1))
y = tf.constant(3.0, shape=(1, 1, 1))
with tf.GradientTape() as g:
g.watch(x)
g.watch(y)
model_copy = do_calc(x, y, meta_model)
g.watch(model_copy.trainable_variables)
# calculate loss of model_copy
test_loss, _ = compute_loss(model_copy, x, y)
# build gradients for meta_model update
gradients_meta = g.gradient(test_loss, meta_model.trainable_variables)
# gradients always None !?!!11 elf
optimizer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))
預先感謝您的任何幫助。
我找到了一個解決方案:我需要以某種方式“連接”元模型和模型副本。
任何人都可以解釋為什么這有效,以及我將如何使用“適當的”優化器來實現這一目標?
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend
def copy_model(model):
copied_model = keras.Sequential()
copied_model.add(keras.layers.Dense(5, input_shape=(1,)))
copied_model.add(keras.layers.Dense(1))
copied_model.set_weights(model.get_weights())
return copied_model
def compute_loss(model, x, y):
logits = model(x) # prediction of my model
mse = keras_backend.mean(keras.losses.mean_squared_error(y, logits)) # compute loss between prediciton and label/truth
return mse, logits
# meta_model to learn in outer gradient tape
meta_model = keras.Sequential()
meta_model.add(keras.layers.Dense(5, input_shape=(1,)))
meta_model.add(keras.layers.Dense(1))
# optimizer for training
optimizer = keras.optimizers.Adam()
# function to calculate model_copys params
def do_calc(meta_model, x, y, gg, alpha=0.01):
model_copy = copy_model(meta_model)
loss, _ = compute_loss(model_copy, x, y)
gradients = gg.gradient(loss, model_copy.trainable_variables)
k = 0
for layer in range(len(model_copy.layers)):
# calculate adapted parameters w/ gradient descent
# \theta_i' = \theta - \alpha * gradients
model_copy.layers[layer].kernel = tf.subtract(meta_model.layers[layer].kernel,
tf.multiply(alpha, gradients[k]))
model_copy.layers[layer].bias = tf.subtract(meta_model.layers[layer].bias,
tf.multiply(alpha, gradients[k + 1]))
k += 2
return model_copy
with tf.GradientTape() as g:
# inputs for training
x = tf.constant(3.0, shape=(1, 1, 1))
y = tf.constant(3.0, shape=(1, 1, 1))
adapted_models = []
# model_copy = meta_model
with tf.GradientTape() as gg:
model_copy = do_calc(meta_model, x, y, gg)
# calculate loss of model_copy
test_loss, _ = compute_loss(model_copy, x, y)
# build gradients for meta_model update
gradients_meta = g.gradient(test_loss, meta_model.trainable_variables)
# gradients work. Why???
optimizer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))
將Tensor轉換為numpy並使用set_weights()只會復制更新后的梯度參數值,但是tf2圖中的節點名稱發生了變化,所以不能直接使用復制的損失model來求梯度元 model
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.