How to run training with pre-made dense layers in tensorflow 2.0?

I am in the process of rewriting my code to be compatible with TF 2.0. Unfortunately, almost every example provided by the website uses the Keras API. I, however, want to write code with raw TensorFlow functions.

The new way of calculating and applying gradients during training looks something like this (code borrowed from another example):

# Optimization process. 
def run_optimization(x, y):
    # Wrap computation inside a GradientTape for automatic differentiation.
    with tf.GradientTape() as g:
        pred = logistic_regression(x)
        loss = cross_entropy(pred, y)

    # Compute gradients.
    gradients = g.gradient(loss, [W, b])

    # Update W and b following gradients.
    optimizer.apply_gradients(zip(gradients, [W, b]))

The problem here is that I have to specify the trainable variables. In this particular case that is easy, because W and b were created manually. It is also easy with a Keras model, through model.trainable_variables.

In my model, I am using dense layers provided by TensorFlow, e.g. tf.keras.layers.Dense. The function provided in TensorFlow 1.x for this use case was tf.trainable_variables(), but it does not exist anymore.

How do I access their internal weights to pass them to the GradientTape?

All Keras layers have a trainable_variables property which you can use to access them. There is also trainable_weights, but in most cases the two are identical. Note that this will be an empty list until the layer has been built, which you can do by calling layer.build(input_shape). Alternatively, a layer is built the first time it is called on an input.
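To illustrate the point about building, here is a minimal sketch, assuming a Dense layer with 4 units and a 3-feature input (both sizes are arbitrary choices for the example, not from the question):

```python
import tensorflow as tf

# Illustrative layer; the unit count and input size are assumptions.
layer = tf.keras.layers.Dense(4)

# Before the layer is built, it owns no variables yet.
n_before = len(layer.trainable_variables)

# Build explicitly; None stands for the (unknown) batch dimension.
layer.build((None, 3))

# After building, the layer holds a kernel and a bias.
n_after = len(layer.trainable_variables)
shapes = [tuple(v.shape) for v in layer.trainable_variables]

print(n_before, n_after, shapes)
```

Calling the layer on a tensor of shape (batch, 3) instead of calling build should create the same variables.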

The Keras equivalent of tf.trainable_variables() is the tf.keras.layers.Layer.trainable_variables property. Unlike the old global function, it is read per layer (or per model).

Here is a more realistic example of how to use GradientTape with a Keras model.
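For the setup in the question (standalone Dense layers, no Keras model), a training step might look roughly like the sketch below. The layer sizes, the mean-squared-error loss, the SGD optimizer, and the toy data are all assumptions made for illustration; the only part taken from the answers is collecting each layer's trainable_variables and passing the combined list to the tape.

```python
import tensorflow as tf

# Hypothetical two-layer network built from standalone Dense layers.
hidden = tf.keras.layers.Dense(8, activation="relu")
out = tf.keras.layers.Dense(1)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def forward(x):
    return out(hidden(x))

def run_optimization(x, y):
    # Wrap computation inside a GradientTape for automatic differentiation.
    with tf.GradientTape() as g:
        pred = forward(x)  # the first call also builds both layers
        loss = tf.reduce_mean(tf.square(pred - y))

    # Collect the variables after the forward pass, so the layers exist.
    variables = list(hidden.trainable_variables) + list(out.trainable_variables)

    # Compute gradients and update the layers' weights.
    gradients = g.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

# Toy regression data: the target is the sum of the three input features.
x = tf.random.normal((32, 3))
y = tf.reduce_sum(x, axis=1, keepdims=True)

for step in range(5):
    loss = run_optimization(x, y)
    print(step, float(loss))
```

Gathering the variables inside the step, after the forward pass, avoids the empty-list pitfall for layers that have not been built yet. If the layers are built up front, the list can be assembled once outside the function instead.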
