How to support masking in custom tf.keras.layers.Layer

I'm implementing a custom tf.keras.layers.Layer that needs to support masking.

Consider the following scenario

import tensorflow as tf

inputs = tf.keras.Input(shape=(None,), dtype="int32")
embedded = tf.keras.layers.Embedding(input_dim=vocab_size + 1,
                                     output_dim=n_dims,
                                     mask_zero=True)(inputs)
x = MyCustomKerasLayers()(embedded)

Now per the documentation

mask_zero: Whether or not the input value 0 is a special "padding" value that should be masked out. This is useful when using recurrent layers which may take variable length input. If this is True then all subsequent layers in the model need to support masking or an exception will be raised. If mask_zero is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of vocabulary + 1).

I wonder, what does that mean? Looking through TensorFlow's custom layers guide and the tf.keras.layers.Layer documentation, it is not clear what should be done to support masking.

  1. How do I support masking?

  2. How do I access the mask from the past layer?

  3. Assuming an input of shape (batch, time, channels) or (batch, time), would the masks look different? What will be their shapes?

  4. How do I pass it on to the next layer?

  1. To support masking, implement the compute_mask method in the custom layer.

  2. To access the mask, add a mask argument to the call method, e.g. call(self, inputs, mask=None). Keras then passes the mask produced by the previous layer to it.

  3. This cannot be determined in general; the preceding layer is responsible for computing the mask. For an Embedding layer with mask_zero=True, the mask is a boolean tensor of shape (batch, time).

  4. Once compute_mask is implemented, the mask is passed on to the next layer automatically. The exception is model subclassing, where it is up to you to compute the masks and pass them on.
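To make the mask shape concrete, here is a small sketch that asks an Embedding layer for its mask directly and then applies it by hand. The token ids and the sizes (input_dim=10, output_dim=8) are made up for illustration.

```python
import tensorflow as tf

# Two padded sequences of token ids; 0 is the padding value (illustrative data).
token_ids = tf.constant([[5, 3, 0, 0],
                         [7, 0, 0, 0]])

# Arbitrary sizes chosen for this sketch.
embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=8, mask_zero=True)

embedded = embedding(token_ids)           # shape (batch, time, channels)
mask = embedding.compute_mask(token_ids)  # boolean, shape (batch, time)

# Broadcast the (batch, time) mask over the channel axis to zero out padded steps.
masked = embedded * tf.cast(mask[:, :, tf.newaxis], embedded.dtype)

print(embedded.shape)
print(mask)
```

The mask has one entry per timestep rather than one per channel, so a custom layer that works on (batch, time, channels) input usually has to add a trailing axis before combining the mask with its input, as in the last step above.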

Example:

import tensorflow as tf

class MyCustomKerasLayers(tf.keras.layers.Layer):
    def __init__(self, **kwargs):  # plus any arguments your layer needs
        super().__init__(**kwargs)

    def compute_mask(self, inputs, mask=None):
        # Pass the mask received from the previous layer on to the next layer,
        # or adjust it here if this layer changes the shape of the input.
        return mask

    def call(self, inputs, mask=None):
        # 'mask' holds the mask passed from the previous layer (or None).
        # Placeholder: replace with this layer's actual computation.
        return inputs

Note that this example just passes the mask on unchanged. If the layer outputs a shape different from the one it receives, you should adjust the mask accordingly in compute_mask so that the correct one is passed on.
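As a sketch of that second case, the layer below keeps every second timestep and slices the mask the same way. EverySecondStep is a hypothetical layer invented for this illustration, and compute_mask is called by hand here so the result does not depend on how a particular Keras version propagates masks.

```python
import tensorflow as tf

class EverySecondStep(tf.keras.layers.Layer):
    """Hypothetical layer that keeps timesteps 0, 2, 4, ... of its input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Declare that this layer handles masks.
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # (batch, time, channels) -> (batch, ceil(time / 2), channels)
        return inputs[:, ::2]

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        # The time axis was subsampled, so subsample the mask identically.
        return mask[:, ::2]

layer = EverySecondStep()
x = tf.zeros((2, 4, 8))
incoming_mask = tf.constant([[True, True, True, False],
                             [True, False, False, False]])

y = layer(x, mask=incoming_mask)
outgoing_mask = layer.compute_mask(x, incoming_mask)

print(y.shape)
print(outgoing_mask)
```

If the mask were returned unchanged here, the next layer would receive a mask with four timesteps for an output that only has two.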

EDIT

An explanation is now also included in the tf.keras masking and padding documentation.
