
TensorFlow - Difference between tf.keras.layers.Layer vs tf.keras.Model

The documentation on implementing custom layers with tf.keras gives two classes to inherit from: tf.keras.layers.Layer and tf.keras.Model.

In the context of creating custom layers, what is the difference between the two? What is technically different?

If I were to implement the transformer encoder, for example, which one would be more suitable? (Assume the transformer is only a "layer" in my full model.)

In the documentation:

The Model class has the same API as Layer, with the following differences:

  • It exposes built-in training, evaluation, and prediction loops (model.fit(), model.evaluate(), model.predict()).
  • It exposes the list of its inner layers, via the model.layers property.
  • It exposes saving and serialization APIs.

Effectively, the "Layer" class corresponds to what we refer to in the literature as a "layer" (as in "convolution layer" or "recurrent layer") or as a "block" (as in "ResNet block" or "Inception block").

Meanwhile, the "Model" class corresponds to what is referred to in the literature as a "model" (as in "deep learning model") or as a "network" (as in "deep neural network").

So if you want to call .fit(), .evaluate(), or .predict() on those blocks, or you want to save and load those blocks separately, use the Model class. The Layer class is leaner, so your layers do not carry functionality they never use, though I would guess that this is rarely a big problem in practice.
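As a rough sketch of the distinction described above, assuming TensorFlow 2.x is installed: the hypothetical EncoderBlock below is a simplified transformer-encoder-style block written as a Layer and then used inside a Model. The class name, the sizes (d_model, num_heads, d_ff), and the toy pooling/Dense head are illustrative assumptions, not something taken from the documentation.

```python
import tensorflow as tf


class EncoderBlock(tf.keras.layers.Layer):
    """Hypothetical, simplified transformer-encoder-style block (illustration only)."""

    def __init__(self, d_model=16, num_heads=2, d_ff=32, **kwargs):
        super().__init__(**kwargs)
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads
        )
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation="relu"),
            tf.keras.layers.Dense(d_model),
        ])
        self.norm1 = tf.keras.layers.LayerNormalization()
        self.norm2 = tf.keras.layers.LayerNormalization()

    def call(self, x):
        # Self-attention with a residual connection, then a feed-forward sublayer.
        x = self.norm1(x + self.attention(x, x))
        return self.norm2(x + self.ffn(x))


block = EncoderBlock()

# Model is itself a subclass of Layer; it adds the training and saving APIs on top.
print(issubclass(tf.keras.Model, tf.keras.layers.Layer))
print(hasattr(block, "fit"))  # a plain Layer has no fit()

# The block becomes one layer of the full model (toy head, assumed for illustration).
inputs = tf.keras.Input(shape=(10, 16))
pooled = tf.keras.layers.GlobalAveragePooling1D()(block(inputs))
outputs = tf.keras.layers.Dense(1)(pooled)
model = tf.keras.Model(inputs, outputs)

print(hasattr(model, "fit"))
print([layer.name for layer in model.layers])  # inner layers, including the block
```

Only the outer Model gets fit(), evaluate(), and predict(); the block stays a plain Layer, which fits the case where the encoder is just one part of the full model.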

  • A layer takes in a tensor and gives out a tensor that is the result of some tensor operations.
  • A model is a composition of multiple layers.

If you are building a new model architecture out of existing Keras/TensorFlow layers, build a custom model.
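A minimal sketch of that case, assuming TensorFlow 2.x and NumPy: a custom Model that only composes existing Keras layers. SmallClassifier, its layer sizes, and the random training data are made up for illustration.

```python
import numpy as np
import tensorflow as tf


class SmallClassifier(tf.keras.Model):
    """Hypothetical model built only from existing Keras layers."""

    def __init__(self, hidden_units=8, num_classes=3):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.out = tf.keras.layers.Dense(num_classes)  # raw logits

    def call(self, inputs):
        return self.out(self.hidden(inputs))


model = SmallClassifier()
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Random data, only to exercise the built-in loops.
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 3, size=(32,))

history = model.fit(x, y, epochs=1, verbose=0)  # built-in training loop
preds = model.predict(x, verbose=0)             # built-in prediction loop
print(preds.shape)
print([layer.name for layer in model.layers])   # the two Dense layers
```

No new tensor operations are defined here; the class only wires existing layers together and inherits the training, evaluation, and prediction loops from Model.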

If you are implementing your own custom tensor operations within a layer, build a custom layer. If the layer uses operations that TensorFlow cannot differentiate automatically (non-tensor operations), you have to code both the forward pass and the backward pass (the gradient) yourself.
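Both cases can be sketched as follows, assuming TensorFlow 2.x. ScaleShift is a hypothetical layer built from ordinary tensor operations, so TensorFlow derives its gradients automatically. round_straight_through is a hypothetical function that supplies its own backward pass through tf.custom_gradient; the straight-through gradient it uses is a choice made for this illustration, not something the answer above prescribes.

```python
import tensorflow as tf


class ScaleShift(tf.keras.layers.Layer):
    """Hypothetical custom layer: y = x * scale + shift, one weight per feature."""

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones", trainable=True
        )
        self.shift = self.add_weight(
            name="shift", shape=(input_shape[-1],), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        # Ordinary tensor operations: no backward pass has to be written by hand.
        return inputs * self.scale + self.shift


layer = ScaleShift()
x = tf.constant([[1.0, 2.0, 3.0]])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))
grads = tape.gradient(loss, layer.trainable_variables)
print(len(grads))  # one gradient per weight


@tf.custom_gradient
def round_straight_through(x):
    """Forward pass rounds; backward pass is written by hand (assumed choice)."""

    def grad(upstream):
        # Pass the incoming gradient through unchanged.
        return upstream

    return tf.round(x), grad


v = tf.constant([0.4, 1.6])
with tf.GradientTape() as tape:
    tape.watch(v)
    rounded = round_straight_through(v)
    total = tf.reduce_sum(rounded)
g = tape.gradient(total, v)
print(rounded.numpy(), g.numpy())
```

A function like round_straight_through can be called from inside a layer's call() in the same way as any built-in operation.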
