繁体   English   中英

为什么损失 function 在 tf.GradientTape 块内部而梯度计算在外部?

[英]Why loss function is inside tf.GradientTape block and gradients calculation is outside?

我是Tensorflow的新手。 在教科书示例中,我看到以下代码旨在使用Tensorflow 2.x API 训练简单的线性 model:

m = tf.Variable(0.)
b = tf.Variable(0.)
def predict_y_value(x):
    y = m * x + b
    return y
def squared_error(y_pred, y_true):
    return tf.reduce_mean(tf.square(y_pred - y_true))
learning_rate = 0.05
steps = 500
for i in range(steps):
    with tf.GradientTape() as tape:
        predictions = predict_y_value(x_train)
        loss = squared_error(predictions, y_train)
    gradients = tape.gradient(loss, [m, b])
    m.assign_sub(gradients[0] * learning_rate)
    b.assign_sub(gradients[1] * learning_rate)
print ("m: %f, b: %f" % (m.numpy(), b.numpy()))

为什么需要将 loss function 的定义包含在with tf.GradientTape() as tape ,但gradients = tape.gradient(loss, [m, b])代码行在块with

我知道它可能是Python语言特定的,但这种结构对我来说似乎不清楚。 Python上下文管理器在这里的作用是什么?

来自 tensorflow 文档,

默认情况下,GradientTape 将自动监视在上下文中访问的任何可训练变量。

直观地说,这种方法极大地提高了灵活性。 例如,它允许您编写(伪)代码如下:

inputs, labels = get_training_batch()
inputs_preprocessed = some_tf_ops(inputs)
with tf.GradientTape() as tape:
    pred = model(inputs_preprocessed)
    loss = compute_loss(labels, pred)

grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# For example, let's attach a model that takes the above model's output as input
next_step_inputs, next_step_labels = process(pred)

with tf.GradientTape() as tape:
    pred = another_model(next_step_inputs)
    another_loss = compute_loss(next_step_labels, pred)

grads = tape.gradient(another_loss, another_model.trainable_variables)
optimizer.apply_gradients(zip(grads, another_model.trainable_variables))

上面的例子可能看起来很复杂,但它解释了需要极大灵活性的极端情况。

  1. 我们不希望some_tf_opsprocess在梯度计算中发挥作用,因为它们是预处理步骤。

  2. 我们想计算多个模型的梯度,有一些关系

一个实际的例子是训练 GAN,尽管更简单的实现是可能的。

tape.gradient放在TapeGradient()之外会重置上下文并为垃圾收集器释放资源。

注 2 等效示例:

with tf.GradientTape() as t:
  loss = loss_fn()
with tf.GradientTape() as t:
  loss += other_loss_fn()
t.gradient(loss, ...)         # Only differentiates other_loss_fn, not loss_fn

下面等价于上面

with tf.GradientTape() as t:
  loss = loss_fn()
  t.reset()
  loss += other_loss_fn()
t.gradient(loss, ...)         # Only differentiates other_loss_fn, not loss_fn 

资源

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM