
How can I reuse a Dense layer?

I have a network in TensorFlow, and I want to define a function that passes its input through a tf.layers.dense layer (obviously, the same one each time). I see the reuse argument, but to use it properly it seems I need to keep a global variable just to remember whether my function has already been called. Is there a cleaner way?

I find tf.layers.Dense cleaner than the above answers. All you need is a Dense object defined beforehand. Then you can reuse it any number of times.

import tensorflow as tf

# Define Dense object which is reusable
my_dense = tf.layers.Dense(3, name="optional_name")

# Define some inputs
x1 = tf.constant([[1,2,3], [4,5,6]], dtype=tf.float32)
x2 = tf.constant([[4,5,6], [7,8,9]], dtype=tf.float32)

# Use the Dense layer
y1 = my_dense(x1)
y2 = my_dense(x2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    y1 = sess.run(y1)
    y2 = sess.run(y2)
    print(y1)
    print(y2)

In fact, the tf.layers.dense function internally constructs a Dense object and passes your input to that object. For more details, check the source code.
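
To see why holding on to the object gives weight sharing for free, here is a framework-free sketch in plain Python. ToyDense and toy_dense are hypothetical stand-ins, not TensorFlow APIs. They only mimic the pattern described above: a layer object that creates its weights on the first call, and a functional wrapper that builds a fresh object every time it is invoked.

```python
import random


class ToyDense:
    """Hypothetical stand-in for a dense layer object (not a TensorFlow API)."""

    def __init__(self, units, seed=0):
        self.units = units
        self.seed = seed
        self.kernel = None  # created lazily on the first call
        self.bias = None

    def __call__(self, rows):
        if self.kernel is None:
            # Build the weights once, sized from the first input seen.
            rng = random.Random(self.seed)
            in_dim = len(rows[0])
            self.kernel = [[rng.uniform(-1, 1) for _ in range(self.units)]
                           for _ in range(in_dim)]
            self.bias = [0.0] * self.units
        # y = x @ kernel + bias, written out with plain lists.
        return [[sum(x * self.kernel[i][j] for i, x in enumerate(row)) + self.bias[j]
                 for j in range(self.units)]
                for row in rows]


def toy_dense(rows, units, seed=0):
    """Functional wrapper: builds a fresh layer object on every call."""
    layer = ToyDense(units, seed=seed)
    return layer(rows), layer


# Object style: one layer, called twice, one set of weights.
shared = ToyDense(3)
y1 = shared([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
kernel_after_first_call = shared.kernel
y2 = shared([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
same_weights = shared.kernel is kernel_after_first_call

# Functional style: two calls create two separate layer objects.
_, layer_a = toy_dense([[1.0, 2.0, 3.0]], 3)
_, layer_b = toy_dense([[1.0, 2.0, 3.0]], 3)
separate_objects = layer_a.kernel is not layer_b.kernel
```

In this toy version the row [4, 5, 6] appears in both inputs, so y1[1] and y2[0] come out identical, which is what sharing one set of weights implies.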

As far as I know, there's no cleaner way. The best we can do is wrap tf.layers.dense in our own abstraction and use it as an object, hiding the variable scope machinery:

def my_dense(*args, **kwargs):
  scope = tf.variable_scope(None, default_name='dense').__enter__()
  def f(input):
    r = tf.layers.dense(input, *args, name=scope, **kwargs)
    scope.reuse_variables()
    return r
  return f

a = [[1,2,3], [4,5,6]]
a = tf.constant(a, dtype=tf.float32)
layer = my_dense(3)
a = layer(a)
a = layer(a)

print(*[[int(a) for a in v.get_shape()] for v in tf.trainable_variables()])
# Prints: "[3, 3] [3]" (one pair of (weights and biases))

You could construct the layer against a constant of the right size and ignore the result.

This way the variable is declared, but the operation should be pruned from the graph.

For example

# Declare the variables once; the shape must be passed as a list
tf.layers.dense(tf.zeros([1, 128]), 3, name='my_layer')

# ... later
hidden = tf.layers.dense(input, 3, name='my_layer', reuse=True)
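
The declare-first, reuse-later pattern works because variables are looked up by name. The sketch below is a plain-Python toy of that lookup, assuming a simple global dictionary. toy_get_variable and _variables are hypothetical names for illustration, not TensorFlow APIs, and the zero initializer is arbitrary.

```python
_variables = {}  # hypothetical global store keyed by variable name


def toy_get_variable(name, shape, reuse=False):
    """Create a named variable, or fetch the existing one when reuse=True."""
    if reuse:
        if name not in _variables:
            raise ValueError("Variable %s does not exist" % name)
        return _variables[name]
    if name in _variables:
        raise ValueError("Variable %s already exists" % name)
    rows, cols = shape
    _variables[name] = [[0.0] * cols for _ in range(rows)]
    return _variables[name]


# First use declares the variable; the later use fetches it by name.
first = toy_get_variable("my_layer/kernel", (128, 3))
again = toy_get_variable("my_layer/kernel", (128, 3), reuse=True)
shares = first is again

# Declaring the same name twice without reuse is rejected.
try:
    toy_get_variable("my_layer/kernel", (128, 3))
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```

This is also why the question's concern arises: the caller has to know whether the name was already declared in order to pick the right value for reuse.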
