Why is the loss function inside the tf.GradientTape block while the gradient calculation is outside?
I am new to TensorFlow. In a textbook example, I came across the following code, which trains a simple linear model using the TensorFlow 2.x API:
import tensorflow as tf

# x_train and y_train are assumed to be defined earlier
m = tf.Variable(0.)
b = tf.Variable(0.)

def predict_y_value(x):
    y = m * x + b
    return y

def squared_error(y_pred, y_true):
    return tf.reduce_mean(tf.square(y_pred - y_true))

learning_rate = 0.05
steps = 500

for i in range(steps):
    with tf.GradientTape() as tape:
        predictions = predict_y_value(x_train)
        loss = squared_error(predictions, y_train)
    gradients = tape.gradient(loss, [m, b])
    m.assign_sub(gradients[0] * learning_rate)
    b.assign_sub(gradients[1] * learning_rate)

print("m: %f, b: %f" % (m.numpy(), b.numpy()))
Why does the loss computation have to be inside the with tf.GradientTape() as tape block, while the line gradients = tape.gradient(loss, [m, b]) sits outside the with block? I understand this may be Python-specific, but the structure is unclear to me. What is the role of the Python context manager here?
From the TensorFlow documentation:

"By default, GradientTape will automatically watch any trainable variables that are accessed inside the context."

Intuitively, this design buys a great deal of flexibility. For example, it allows you to write (pseudo)code like the following:
inputs, labels = get_training_batch()
inputs_preprocessed = some_tf_ops(inputs)

with tf.GradientTape() as tape:
    pred = model(inputs_preprocessed)
    loss = compute_loss(labels, pred)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# For example, let's attach a model that takes the above model's output as input
next_step_inputs, next_step_labels = process(pred)

with tf.GradientTape() as tape:
    pred = another_model(next_step_inputs)
    another_loss = compute_loss(next_step_labels, pred)
grads = tape.gradient(another_loss, another_model.trainable_variables)
optimizer.apply_gradients(zip(grads, another_model.trainable_variables))
The example above may look complicated, but it illustrates corner cases that call for this extreme flexibility:

- We don't want some_tf_ops and process to take part in the gradient calculation, because they are preprocessing steps.
- We want to compute gradients for multiple models that are related to each other; a practical example would be training a GAN, although simpler implementations are possible. A runnable sketch of this two-tape pattern is given below.
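The following is a minimal runnable sketch of that two-tape pattern. The models, the random data, and the names first, second, x, y1, y2 are all invented for illustration; only the taping structure mirrors the pseudocode above.

import tensorflow as tf

# Two small related models: `second` consumes the output of `first`.
first = tf.keras.Sequential([tf.keras.layers.Dense(4, activation="relu"),
                             tf.keras.layers.Dense(1)])
second = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt1 = tf.keras.optimizers.SGD(0.05)
opt2 = tf.keras.optimizers.SGD(0.05)
mse = tf.keras.losses.MeanSquaredError()

# Stand-ins for get_training_batch() / some_tf_ops(); nothing here is taped.
x = tf.random.normal([8, 3])
y1 = tf.random.normal([8, 1])
y2 = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
    pred = first(x)
    loss = mse(y1, pred)
grads = tape.gradient(loss, first.trainable_variables)
opt1.apply_gradients(zip(grads, first.trainable_variables))

# pred was computed under the previous tape, not this one, so the ops of
# `first` are not recorded here and only `second` receives gradients.
with tf.GradientTape() as tape:
    pred2 = second(pred)
    loss2 = mse(y2, pred2)
grads2 = tape.gradient(loss2, second.trainable_variables)
opt2.apply_gradients(zip(grads2, second.trainable_variables))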
Calling tape.gradient outside the with tf.GradientTape() block means recording has already stopped; by default, the resources a tape holds are released as soon as tape.gradient() is called, freeing them for the garbage collector.
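One practical consequence is that a non-persistent tape can compute only a single gradient, because its records are discarded on the first tape.gradient call. When several gradients are needed from one forward pass, the API provides persistent tapes. A short sketch, closely following the example in the tf.GradientTape documentation:

import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape(persistent=True) as tape:
    y = x * x
    z = y * y
print(tape.gradient(z, x).numpy())  # 4 * x**3 = 108.0
print(tape.gradient(y, x).numpy())  # 2 * x = 6.0
del tape  # a persistent tape holds its resources until dropped explicitly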
Note 2: equivalent examples:
with tf.GradientTape() as t:
    loss = loss_fn()
with tf.GradientTape() as t:
    loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
The following is equivalent to the above:
with tf.GradientTape() as t:
    loss = loss_fn()
    t.reset()
    loss += other_loss_fn()
t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
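To see the reset behavior concretely, here is a small made-up example with scalar losses standing in for loss_fn and other_loss_fn, so it can be run directly:

import tensorflow as tf

x = tf.Variable(2.0)
with tf.GradientTape() as t:
    loss = x * x           # recorded, then discarded by t.reset()
    t.reset()              # clear everything recorded so far; keep taping
    loss = loss + 3.0 * x  # only this part is recorded
print(t.gradient(loss, x).numpy())  # 3.0: the x * x term is not differentiated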