
Custom Keras Layer with Trainable Scalars

I'm trying to write a custom Keras layer that implements the following map, applied componentwise:

x -> a x + b ReLU(x)
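Since ReLU(x) = max(0, x), this map is piecewise linear, with slope a for negative inputs and slope a + b for non-negative ones. For example, a = 0, b = 1 recovers the plain ReLU, and b = 1 - a gives a leaky-ReLU/PReLU-style unit with negative slope a:

```latex
f(x) = a\,x + b\,\max(0, x) =
\begin{cases}
(a + b)\,x, & x \ge 0,\\
a\,x, & x < 0.
\end{cases}
```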

with a and b trainable weights. Here's what I've tried so far:

import tensorflow as tf

class Custom_ReLU(tf.keras.layers.Layer):

    def __init__(self, units=d):
        super(Custom_ReLU, self).__init__()
        self.units = units

    def build(self, input_shape):
        # two trainable coefficients, each stored as a one-element weight
        self.a1 = self.add_weight(shape=[1],
                                  initializer='random_uniform',
                                  trainable=True)
        self.a2 = self.add_weight(shape=[1],
                                  initializer='random_uniform',
                                  trainable=True)

    def call(self, inputs):
        # componentwise a1 * x + a2 * ReLU(x)
        return self.a1*inputs + self.a2*(tf.nn.relu(inputs))

However, I get errors. I think the issue is that I don't know how to define trainable "scalars". Am I right about this, and if so, how do I define them?
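A minimal sketch, assuming TensorFlow 2.x with tf.keras: a trainable scalar is just a weight created with add_weight. Passing shape=() gives a true 0-d scalar, while the question's shape=[1] is a one-element vector that also broadcasts against any input. The class name ScalarReLU and the weight names a and b are illustrative only.

```python
import tensorflow as tf


class ScalarReLU(tf.keras.layers.Layer):
    """Computes a * x + b * relu(x) with two trainable 0-d (scalar) weights."""

    def build(self, input_shape):
        # shape=() creates a scalar variable; it broadcasts against any input shape.
        self.a = self.add_weight(name="a", shape=(), initializer="random_uniform", trainable=True)
        self.b = self.add_weight(name="b", shape=(), initializer="random_uniform", trainable=True)

    def call(self, inputs):
        return self.a * inputs + self.b * tf.nn.relu(inputs)


layer = ScalarReLU()
# Calling the layer once triggers build(), which creates the two scalars.
y = layer(tf.constant([[-2.0, 0.0, 3.0]]))
print([tuple(w.shape) for w in layer.trainable_weights])
```

Both weights end up in layer.trainable_weights, so an optimizer (or model.fit) updates them like any other parameter.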

Edit/Additions:

Here is how I'm trying to build my plain-vanilla feed-forward architecture with ReLU replaced by "Custom_ReLU":

# Build Vanilla Network
inputs_ffNN = tf.keras.Input(shape=(d,))
x_ffNN = fullyConnected_Dense(d)(inputs_ffNN)
for i in range(Depth):
    x_HTC = Custom_ReLU(x_ffNN)
    x_ffNN = fullyConnected_Dense(d)(x_ffNN)
outputs_ffNN = fullyConnected_Dense(D)(x_ffNN)
ffNN = tf.keras.Model(inputs_ffNN, outputs_ffNN)
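A sketch of the same loop with two changes: each Custom_ReLU is constructed first and then called on the tensor, and its output is fed into the next layer. (In the loop above, Custom_ReLU(x_ffNN) passes the tensor to the constructor, and x_HTC is never used.) Here tf.keras.layers.Dense stands in for the asker's fullyConnected_Dense, whose definition is not shown, and d, D and Depth are hypothetical values.

```python
import tensorflow as tf


class Custom_ReLU(tf.keras.layers.Layer):
    """a1 * x + a2 * relu(x) with two trainable scalars (created in build)."""

    def build(self, input_shape):
        self.a1 = self.add_weight(name="a1", shape=(), initializer="random_uniform", trainable=True)
        self.a2 = self.add_weight(name="a2", shape=(), initializer="random_uniform", trainable=True)

    def call(self, inputs):
        return self.a1 * inputs + self.a2 * tf.nn.relu(inputs)


# Hypothetical sizes standing in for the question's d, D and Depth.
d, D, Depth = 5, 1, 3

inputs_ffNN = tf.keras.Input(shape=(d,))
# Assumption: a plain Dense layer replaces the asker's fullyConnected_Dense.
x_ffNN = tf.keras.layers.Dense(d)(inputs_ffNN)
for i in range(Depth):
    # Construct the activation layer, call it on the tensor, and keep the result.
    x_ffNN = Custom_ReLU()(x_ffNN)
    x_ffNN = tf.keras.layers.Dense(d)(x_ffNN)
outputs_ffNN = tf.keras.layers.Dense(D)(x_ffNN)
ffNN = tf.keras.Model(inputs_ffNN, outputs_ffNN)
```

Each loop iteration creates a fresh Custom_ReLU, so every hidden layer gets its own pair of scalars.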

And here is a snippet of the errors:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-27-8bf6fc4ae89d> in <module>
      7     #x_HTC = tf.nn.relu(x_HTC)
      8     x_HTC = BounceLU(x_HTC)
----> 9     x_HTC = HTC(d)(x_HTC)
     10 outputs_HTC = HTC(D)(x_HTC)
     11 ffNN_HTC = tf.keras.Model(inputs_HTC, outputs_HTC)

~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    816         # Eager execution on data tensors.
    817         with backend.name_scope(self._name_scope()):
--> 818           self._maybe_build(inputs)
    819           cast_inputs = self._maybe_cast_inputs(inputs)
    820           with base_layer_utils.autocast_context_manager(

~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
   2114         # operations.
   2115         with tf_utils.maybe_init_scope(self):
-> 2116           self.build(input_shapes)
   2117       # We must set self.built since user defined build functions are not
   2118       # constrained to set self.built.

<ipython-input-5-21623825ed35> in build(self, input_shape)
      5 
      6     def build(self, input_shape):
----> 7         self.w = self.add_weight(shape=(input_shape[-1], self.units),
      8                                initializer='random_normal',
      9                                trainable=False)

TypeError: 'NoneType' object is not subscriptable
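One plausible reading of this traceback, assuming BounceLU is used the same way as Custom_ReLU in the loop above: BounceLU(x_HTC) passes the tensor to the layer's constructor and returns a layer object rather than a tensor. When HTC(d) is then called on that object, Keras cannot read a shape from it, so build receives input_shape=None and input_shape[-1] raises the TypeError. The sketch below only shows the difference between the two call styles; the class mirrors the question's Custom_ReLU signature, and the input tensor is an arbitrary example.

```python
import tensorflow as tf


class Custom_ReLU(tf.keras.layers.Layer):
    # Same constructor signature as in the question: the first argument is `units`.
    def __init__(self, units=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.a1 = self.add_weight(name="a1", shape=(), initializer="random_uniform")
        self.a2 = self.add_weight(name="a2", shape=(), initializer="random_uniform")

    def call(self, inputs):
        return self.a1 * inputs + self.a2 * tf.nn.relu(inputs)


x = tf.ones((2, 5))

wrong = Custom_ReLU(x)    # x is stored as `units`; the result is a Layer, not a tensor
right = Custom_ReLU()(x)  # construct first, then call: the result is a tensor

print(type(wrong).__name__)
print(right.shape)
```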

Answer:

I have no problem using your layer. In this version the two scalar weights are created directly in __init__:

import tensorflow as tf

class Custom_ReLU(tf.keras.layers.Layer):

    def __init__(self):
        super(Custom_ReLU, self).__init__()
        # scalar weights created directly in the constructor
        self.a1 = self.add_weight(shape=[1],
                                  initializer='random_uniform',
                                  trainable=True)
        self.a2 = self.add_weight(shape=[1],
                                  initializer='random_uniform',
                                  trainable=True)

    def call(self, inputs):
        return self.a1*inputs + self.a2*(tf.nn.relu(inputs))
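Because scalar weights do not depend on the input shape, creating them in __init__ rather than in build works as well. A quick sanity check, assuming TensorFlow 2.x eager mode, is to confirm that both scalars receive gradients: for x = [[-1, 2]] the summed output is a1*(-1 + 2) + a2*(0 + 2), so the gradients with respect to a1 and a2 should be 1 and 2.

```python
import tensorflow as tf


class Custom_ReLU(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One-element trainable weights, created eagerly in the constructor.
        self.a1 = self.add_weight(name="a1", shape=[1], initializer="random_uniform", trainable=True)
        self.a2 = self.add_weight(name="a2", shape=[1], initializer="random_uniform", trainable=True)

    def call(self, inputs):
        return self.a1 * inputs + self.a2 * tf.nn.relu(inputs)


layer = Custom_ReLU()
x = tf.constant([[-1.0, 2.0]])
with tf.GradientTape() as tape:
    # loss = a1 * (-1 + 2) + a2 * (0 + 2)
    loss = tf.reduce_sum(layer(x))
grads = tape.gradient(loss, layer.trainable_weights)
print([g.numpy() for g in grads])
```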

Usage (the layer is instantiated first and then called on a tensor, i.e. Custom_ReLU()(x_ffNN), not Custom_ReLU(x_ffNN)):

import numpy as np
import tensorflow as tf

d = 5
inputs_ffNN = tf.keras.Input(shape=(d,))
x_ffNN = tf.keras.layers.Dense(10)(inputs_ffNN)
x_HTC = Custom_ReLU()(x_ffNN)
outputs_ffNN = tf.keras.layers.Dense(1)(x_HTC)

ffNN = tf.keras.Model(inputs_ffNN, outputs_ffNN)
ffNN.compile('adam', 'mse')

ffNN.fit(np.random.uniform(0, 1, (10, 5)), np.random.uniform(0, 1, 10), epochs=10)

Here is the full example: https://colab.research.google.com/drive/1n4jIsY3qEDvtobofQaUPO3ysUW9bQWjs?usp=sharing
