How to pass the weights of previous layers as input to the call function of a custom layer in a functional Keras model?

For the call method of my custom layer I need the weights of some preceding layers, but I don't need to modify them, only to access their values. I got the values as suggested in How do I get the weights of a layer in Keras?, but this returns the weights as NumPy arrays. So I converted them to tensors (using tf.convert_to_tensor from the Keras backend), but when the model is created I get this error: "'NoneType' object has no attribute '_inbound_nodes'". How can I fix this problem? Thank you.

In TensorFlow 1.x graph mode, graph collections group the variables. To access the trainable variables, call tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) or its shorthand tf.trainable_variables(); to get all variables (including non-trainable ones, such as statistics), use tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) or its shorthand tf.global_variables(). The older names tf.GraphKeys.VARIABLES and tf.all_variables() are deprecated aliases for these.

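import tensorflow as tf  # TensorFlow 1.x graph-mode API

tvars = tf.trainable_variables()
tvars_vals = sess.run(tvars)  # `sess` is an existing tf.Session whose variables are initialized

for var, val in zip(tvars, tvars_vals):
    print(var.name, val)  # Prints the name of the variable alongside its value.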
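
To try the collection approach without an existing session, here is a minimal sketch that builds a small graph with tf.compat.v1 (the TF1 API as it is exposed in TensorFlow 2.x). The variables `w` and `step` are hypothetical, created only so the collections have something in them.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, still available in TensorFlow 2.x

graph = tf.Graph()
with graph.as_default():  # graph mode inside this block, so collections are populated
    # Hypothetical variables: one trainable, one not (e.g. a step counter).
    w = tf1.get_variable("w", shape=(2, 2), initializer=tf1.ones_initializer())
    step = tf1.get_variable("step", shape=(), trainable=False,
                            initializer=tf1.zeros_initializer())

    # Same as tf1.get_collection(tf1.GraphKeys.TRAINABLE_VARIABLES)
    tvars = tf1.trainable_variables()
    # Same as tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)
    all_vars = tf1.global_variables()

    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        tvars_vals = sess.run(tvars)  # NumPy arrays with the current values

for var, val in zip(tvars, tvars_vals):
    print(var.name, val)  # only the trainable variable "w:0" is listed
```

Note that `step` shows up in `all_vars` but not in `tvars`, because it was created with `trainable=False`.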

You can pass the preceding layer to the constructor of your custom layer class and keep a reference to it.

Custom Layer:

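import tensorflow as tf
from tensorflow.keras.layers import Layer

class CustomLayer(Layer):
    def __init__(self, reference_layer, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        self.ref_layer = reference_layer  # preceding layer, kept by reference

    def call(self, inputs):
        # Read the variables themselves (for Dense: [kernel, bias]);
        # get_weights() would return NumPy copies of their current values instead.
        kernel, bias = self.ref_layer.weights
        # Do something with these weights, e.g. (placeholder) re-apply them:
        return tf.matmul(inputs, kernel) + bias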
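
The difference between get_weights() and the weights attribute matters when one layer reads another layer's parameters. A minimal sketch, assuming tf.keras and a standalone Dense layer built by hand (the sizes are arbitrary): get_weights() returns NumPy copies taken at the moment it is called, while weights returns the variable objects themselves, so only the latter reflects later updates such as training steps.

```python
import numpy as np
import tensorflow as tf

dense = tf.keras.layers.Dense(3)
dense.build((None, 2))  # creates the kernel (2, 3) and the bias (3,)

snapshot = dense.get_weights()          # list of NumPy arrays: copies of the current values
live_kernel, live_bias = dense.weights  # the variables themselves

dense.bias.assign(tf.ones(3))  # update the bias in place, as a training step would

print(snapshot[1])          # the copy keeps the old, zero-initialized bias
print(np.array(live_bias))  # the variable reflects the update (all ones)
```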

Now add this layer to your model using the functional API.

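from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

inp = Input(shape=(5,))
dense = Dense(5)
custom_layer = CustomLayer(dense)  # pass the layer instance here

# model: dense is called (and therefore built) before custom_layer reads its weights
x = dense(inp)
x = custom_layer(x)
model = Model(inputs=inp, outputs=x)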

Here custom_layer can access the weights of the layer dense.
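
Putting the pieces together, here is a self-contained sketch that checks the behaviour end to end. It assumes tf.keras; the operation in call (re-applying the kernel and bias of dense) is a hypothetical placeholder for "do something with these weights", and the layer reads the variables through the weights attribute rather than get_weights().

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Layer
from tensorflow.keras.models import Model


class CustomLayer(Layer):
    """Hypothetical example: re-applies the referenced Dense layer's kernel and bias."""

    def __init__(self, reference_layer, **kwargs):
        super().__init__(**kwargs)
        self.ref_layer = reference_layer

    def call(self, inputs):
        kernel, bias = self.ref_layer.weights  # live variables, not NumPy copies
        return tf.matmul(inputs, kernel) + bias


inp = Input(shape=(5,))
dense = Dense(5)
custom_layer = CustomLayer(dense)
out = custom_layer(dense(inp))  # dense is built before custom_layer runs
model = Model(inputs=inp, outputs=out)

x = np.random.default_rng(0).random((4, 5)).astype("float32")
k, b = dense.get_weights()
expected = (x @ k + b) @ k + b  # dense, then the custom layer with the same weights
result = model(x).numpy()
print(np.allclose(result, expected, atol=1e-5))

# Replace the Dense weights after the model is built: identity kernel, zero bias.
dense.set_weights([np.eye(5, dtype="float32"), np.zeros(5, dtype="float32")])
updated = model(x).numpy()
print(np.allclose(updated, x))  # the custom layer sees the new values too
```

Because the custom layer holds a reference to the layer object itself, it keeps seeing the current values after set_weights() or training, with no manual conversion of NumPy arrays to tensors.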
