
Eagerly update a keras model's weights directly using the gradient

I am writing a custom optimizer with Eager Execution in TensorFlow 1.15 but can't figure out how to update the weights. Taking gradient descent as an example, I have the weights, the gradient and a scalar learning rate, but can't figure out how to combine them.

This is an implementation of gradient descent, where model is a keras.Model, e.g. a multilayer CNN:

import tensorflow as tf

lr = tf.constant(0.01)

def minimize(model, inputs, targets):
    with tf.GradientTape() as tape:
        logits = model(inputs)
        loss_value = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=targets)

    grad = tape.gradient(loss_value, model.trainable_variables)
    step = tf.multiply(lr, grad)
    model.trainable_variables.assign_sub(step)

but it fails on the tf.multiply call with

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [5,5,1,6] != values[1].shape = [6] [Op:Pack] name: packed

I also know the last line will fail, as trainable_variables is a list and doesn't have the method assign_sub.


How can I rewrite the last two lines of my code to do the equivalent of:

model.trainable_variables -= lr * grad

Figured it out. As both are lists, we need to iterate through the pairs of gradients and variables for each layer together and update each of them separately.

lr = tf.constant(0.01)

def minimize(model, inputs, targets):
    with tf.GradientTape() as tape:
        logits = model(inputs)
        loss_value = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=targets)

    grad = tape.gradient(loss_value, model.trainable_variables)
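    # pair each trainable variable with its gradient and apply an in-place SGD step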
    for v, g in zip(model.trainable_variables, grad):
        v.assign_sub(lr * g)
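For completeness, here is a minimal sketch of how minimize might be called in a training loop under TF 1.15 eager execution. The model architecture, random batch and step count below are hypothetical placeholders for illustration only; they are not part of the original question.

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # enable eager execution at program start in TF 1.x

# hypothetical small CNN; the first conv layer has 5x5x1x6 kernels, matching the shape in the error above
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# hypothetical random batch standing in for real training data
x = tf.constant(np.random.rand(32, 28, 28, 1).astype(np.float32))
y = tf.one_hot(np.random.randint(0, 10, size=32), depth=10)

for step in range(100):
    minimize(model, x, y)  # one manual gradient-descent update per call

Each call to minimize applies one plain gradient-descent step; this is essentially what tf.keras.optimizers.SGD(learning_rate=0.01).apply_gradients(zip(grad, model.trainable_variables)) would do without momentum.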
