
Get batch size in Keras custom layer and use tensorflow operations (tf.Variable)

I would like to write a Keras custom layer with TensorFlow operations that require the batch size as input. Apparently I'm struggling with every part of it.

Suppose a very simple layer that (1) gets the batch size, (2) creates a tf.Variable (let's call it my_var) based on the batch size, then applies some tf.random ops to alter my_var, and (3) finally returns the input multiplied by my_var.

What I tried so far:

class TestLayer(Layer):

    def __init__(self, **kwargs):

        self.num_batch = None
        self.my_var = None

        super(TestLayer, self).__init__(**kwargs)

    def build(self, input_shape):

        self.batch_size = input_shape[0]

        var_init = tf.ones(self.batch_size, dtype=self.dtype)  # self.batch_size is None here
        self.my_var = tf.Variable(var_init, trainable=False, validate_shape=False)

        # some tensorflow random operations to alter self.my_var

        super(TestLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):

        return self.my_var * x

    def compute_output_shape(self, input_shape):

        return input_shape

Now creating a very simple model:

# define model
input_layer = Input(shape = (2, 2, 3), name = 'input_layer')
x = TestLayer()(input_layer)

# connect model
my_mod = Model(inputs = input_layer, outputs = x)
my_mod.summary()

Unfortunately, whatever I try or change in the code, I get multiple errors, most of them with very cryptic tracebacks (ValueError: Cannot convert a partially known TensorShape to a Tensor: or ValueError: None values not supported.).

Any general suggestions? Thanks in advance.
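A quick way to see what build actually receives: without batch_size, Keras leaves the leading (batch) dimension of the input unknown, so input_shape[0] is None and cannot be used to size a tensor or variable. If the batch size only has to be known at run time, one common alternative is to read it from tf.shape(x)[0] inside call, which is evaluated with the real batch. The sketch below assumes the random per-sample factors do not need to persist between calls, so it swaps the tf.Variable for a plain tensor created in call; DynamicBatchLayer is a hypothetical name.

```python
import tensorflow as tf

# Without batch_size, the static batch dimension is unknown (None).
inp = tf.keras.Input(shape=(2, 2, 3))
print(inp.shape)  # (None, 2, 2, 3)


class DynamicBatchLayer(tf.keras.layers.Layer):
    """Hypothetical layer: scales each sample by its own random factor.

    Assumption: the factors are redrawn on every call instead of being
    stored in a tf.Variable, so no fixed batch size is required.
    """

    def call(self, x):
        # tf.shape(x) is a run-time tensor, so its first entry is the
        # actual batch size even when the static shape has None there.
        batch_size = tf.shape(x)[0]
        # One factor per sample, shaped (batch_size, 1, 1, 1) so it
        # broadcasts against inputs of shape (batch_size, 2, 2, 3).
        factors = tf.random.uniform([batch_size, 1, 1, 1], dtype=x.dtype)
        return factors * x


layer = DynamicBatchLayer()
print(layer(tf.ones([4, 2, 2, 3])).shape)  # (4, 2, 2, 3)
print(layer(tf.ones([7, 2, 2, 3])).shape)  # (7, 2, 2, 3)
```

Because nothing depends on the static batch dimension, the same layer instance accepts batches of any size, including a smaller final batch.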

You need to specify the batch size (via the batch_size argument of Input) if you want to create a variable of size batch_size; otherwise input_shape[0] is None inside build. Additionally, if you want to print a summary, the tf.Variable must have a fixed shape (validate_shape=True), and it must be broadcastable against the input so that the multiplication succeeds:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.models import Model

class TestLayer(Layer):

    def __init__(self, **kwargs):
        self.num_batch = None
        self.my_var = None
        super(TestLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.batch_size = input_shape[0]
        # Shape (batch_size, 1, 1, 1), so it broadcasts against inputs of shape (batch_size, 2, 2, 3)
        var_init = tf.ones(self.batch_size, dtype=tf.float32)[..., None, None, None]
        self.my_var = tf.Variable(var_init, trainable=False, validate_shape=True)
        super(TestLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        res = self.my_var * x
        return res

    def compute_output_shape(self, input_shape):
        return input_shape

# define model
input_layer = Input(shape=(2, 2, 3), name='input_layer', batch_size=10)
x = TestLayer()(input_layer)

# connect model
my_mod = Model(inputs=input_layer, outputs=x)
my_mod.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_layer (InputLayer)     (10, 2, 2, 3)             0         
_________________________________________________________________
test_layer (TestLayer)       (10, 2, 2, 3)             0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
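The answer leaves out step (2) of the question, the random ops that alter my_var. A minimal sketch, assuming a fixed batch size as above and uniform noise as the example random op (RandomScaleLayer is a hypothetical name): create the non-trainable variable with a broadcastable shape in build, then overwrite its contents with tf.Variable.assign.

```python
import tensorflow as tf


class RandomScaleLayer(tf.keras.layers.Layer):
    """Hypothetical variant of TestLayer whose per-sample factors are random."""

    def build(self, input_shape):
        batch_size = input_shape[0]
        if batch_size is None:
            # Same root cause as the question's errors: unknown batch size.
            raise ValueError("Fix the batch size, e.g. Input(..., batch_size=10)")
        # Non-trainable variable of shape (batch_size, 1, 1, 1).
        self.my_var = tf.Variable(tf.ones([batch_size, 1, 1, 1]), trainable=False)
        # Example random op (an assumption): replace the ones with
        # samples drawn uniformly from [0, 1).
        self.my_var.assign(tf.random.uniform([batch_size, 1, 1, 1]))
        super().build(input_shape)

    def call(self, x):
        # Broadcasts (batch_size, 1, 1, 1) against (batch_size, 2, 2, 3).
        return self.my_var * x


layer = RandomScaleLayer()
out = layer(tf.ones([10, 2, 2, 3]))
print(out.shape)  # (10, 2, 2, 3)
```

Note that a variable sized this way only fits batches of exactly batch_size samples; a smaller final batch would fail to broadcast, so with tf.data one option is to batch with drop_remainder=True.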
