
How does build() work in TensorFlow's tf.keras.layers.Layer?

I was wondering whether anyone knows how the build() method of the tf.keras.layers.Layer class works under the hood. According to the documentation:

build is called when you know the shapes of the input tensors and can do the rest of the initialization

So it seems to me that the class behaves roughly like this:

import tensorflow as tf

class MyDenseLayer:
  def __init__(self, num_outputs):
    self.num_outputs = num_outputs

  def build(self, input_shape):
    # A plain class has no add_weight(), so create the variable directly
    self.kernel = tf.Variable(
        tf.random.normal([int(input_shape[-1]), self.num_outputs]), name="kernel")

  def __call__(self, input):
    self.build(input.shape)  # build is called here, when the input shape is known
    return tf.matmul(input, self.kernel)

I can't imagine that build() would be called on every __call__, but that is the only place where the input is passed in. Does anyone know exactly how this works under the hood?

The Layer.build() method is typically used to create the layer's weights. See the source code of tf.keras.layers.Dense for an example: the kernel (weight) and bias variables are created in that method. Layer.build() takes an input_shape argument because the shapes of the weights and biases often depend on the shape of the input.
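You can observe this with the built-in Dense layer itself. The sketch below assumes TensorFlow 2.x with the tf.keras API: a freshly constructed Dense layer has no weights yet, because it does not know the input's last dimension; the kernel and bias appear only after the first call.

```python
import tensorflow as tf

dense = tf.keras.layers.Dense(4)
# Not built yet: no input shape has been seen, so there are no weights
print(dense.built, len(dense.weights))

x = tf.zeros([2, 3])  # batch of 2 samples with 3 features each
y = dense(x)          # the first call triggers build() with the input shape

print(dense.built)          # now built
print(dense.kernel.shape)   # kernel shape depends on the input's last dimension
print(dense.bias.shape)     # bias shape depends only on the number of units
```

The number of units (4) is fixed in the constructor, but the kernel's first dimension (3) is only known once real input arrives, which is why creation is deferred to build().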

The Layer.call() method, on the other hand, implements the layer's forward pass. You do not want to override __call__, because it is implemented in the base class tf.keras.layers.Layer. In a custom layer, implement call() instead.
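Here is the question's layer rewritten as a proper subclass, as a minimal sketch assuming TensorFlow 2.x: weights are created in build(), the forward pass lives in call(), and __call__ is inherited unchanged from tf.keras.layers.Layer. The "glorot_uniform" initializer is an illustrative choice, not something the original specifies.

```python
import tensorflow as tf

class MyDenseLayer(tf.keras.layers.Layer):
    """Dense-like layer without bias: weights in build(), forward pass in call()."""

    def __init__(self, num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        # The kernel's first dimension comes from the input's last dimension.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]), self.num_outputs),
            initializer="glorot_uniform",
        )
        super().build(input_shape)  # the base implementation marks the layer as built

    def call(self, inputs):
        # Forward pass only; no shape bookkeeping is needed here.
        return tf.matmul(inputs, self.kernel)


layer = MyDenseLayer(10)
out = layer(tf.ones([5, 3]))  # inherited __call__ runs build() first, then call()
print(out.shape)
```

Note that nothing in this class calls build() explicitly; the inherited __call__ takes care of it.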

Layer.call() does not call Layer.build(). However, Layer.__call__() does call it if the layer has not been built yet (source), and doing so sets self.built = True, which prevents Layer.build() from being called again. In other words, Layer.__call__() calls Layer.build() only the first time it is called.
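The "build once" guard can be mimicked in plain Python. This is a simplified sketch, not the actual Keras source (the real Layer.__call__ also handles name scopes, dtype casting, masking and more); LazyLayer, build_calls and _shape_of are hypothetical names used only for illustration.

```python
class LazyLayer:
    """Simplified stand-in for the lazy-build pattern used by Keras layers."""

    def __init__(self):
        self.built = False
        self.build_calls = 0  # counts how often build() actually runs

    def build(self, input_shape):
        # A real layer would create its weights here; we only record the call.
        self.build_calls += 1
        self.input_dim = input_shape[-1]
        self.built = True  # this flag stops future builds

    def call(self, inputs):
        # Placeholder forward pass: returns the inputs unchanged.
        return inputs

    def __call__(self, inputs):
        # Build lazily, on the first call only, once the shape is known.
        if not self.built:
            self.build(self._shape_of(inputs))
        return self.call(inputs)

    @staticmethod
    def _shape_of(inputs):
        # Shape of a nested-list "batch": (rows, columns).
        return (len(inputs), len(inputs[0]))


layer = LazyLayer()
batch = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
layer(batch)
layer(batch)
print(layer.build_calls, layer.input_dim)  # prints: 1 3
```

Compared with the version in the question, the only essential change is the `if not self.built` check in __call__, which is why build() runs once rather than on every call.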
