
Is it possible to create a model in Keras Functional API without an input layer?

I would like to create a model consisting of 2 convolutional, one flatten, and one dense layer in Keras. This would be a model with shared weights, so without any predefined input layer.

This works with the Sequential API:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(10,3,2,'valid',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(20,3,2,'valid',activation=tf.nn.relu))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(200,activation=tf.nn.relu))

However, doing the same with the Functional API produces a TypeError:

model2 = tf.keras.layers.Conv2D(10,3,2,'valid',activation=tf.nn.relu)
model2 = tf.keras.layers.Conv2D(20,3,2,'valid',activation=tf.nn.relu)(model2)
model2 = tf.keras.layers.Flatten()(model2)
model2 = tf.keras.layers.Dense(200,activation=tf.nn.relu)(model2)

Error:

TypeError: Inputs to a layer should be tensors. Got: <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb060598100>

Is it impossible to do this way, or am I missing something?

The Keras Sequential API is designed to be easier to use, and as a result it is less flexible than the Functional API. The benefit is that the input shape can be inferred automatically from the shape of the data you first pass to the model, so no explicit input layer is needed. The downside is that this simpler model cannot express things like multiple inputs.
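As a rough illustration of that deferred shape inference, the sketch below builds the Sequential model from the question and checks the first convolution's weights before and after the first call. The 64x64 RGB batch is an assumed example shape, not something taken from the question.

```python
import numpy as np
import tensorflow as tf

# Keep a reference to the first layer so its weights can be inspected.
conv1 = tf.keras.layers.Conv2D(10, 3, 2, 'valid', activation=tf.nn.relu)

model = tf.keras.models.Sequential([
    conv1,
    tf.keras.layers.Conv2D(20, 3, 2, 'valid', activation=tf.nn.relu),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(200, activation=tf.nn.relu),
])

# No input shape is known yet, so the layer has not created any weights.
weights_before = len(conv1.weights)

# The first call fixes the input shape and creates kernel and bias.
# (1, 64, 64, 3) is an assumed example batch: one 64x64 RGB image.
batch = np.zeros((1, 64, 64, 3), dtype="float32")
output = model(batch)
weights_after = len(conv1.weights)

print(weights_before, weights_after, output.shape)
```

So the Sequential model does have an input shape in the end; it is just decided at the first call instead of at construction time.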

From the keras docs:

A Sequential model is not appropriate when:

  • Your model has multiple inputs or multiple outputs
  • Any of your layers has multiple inputs or multiple outputs
  • You need to do layer sharing
  • You want non-linear topology (e.g. a residual connection, a multi-branch model)

The Functional API is designed to be more flexible (for example, it supports multiple inputs), so it does not make any automatic inference for you. Each layer must be called on a tensor, and in your code the second Conv2D is called on a layer object instead, hence the error. You must explicitly pass an input layer in this case. For your use case it might seem odd that the shape is not inferred automatically, but it makes sense once you consider the wider range of models the Functional API has to support.

So the second scenario should be:

# Explicit input layer. (64,64,3) is only an example shape (height, width, channels);
# an input as small as (10,3,2) would be too narrow for two 3x3, stride-2 'valid' convolutions.
model2 = tf.keras.layers.Input((64,64,3))
model2 = tf.keras.layers.Conv2D(10,3,2,'valid',activation=tf.nn.relu)(model2)
model2 = tf.keras.layers.Conv2D(20,3,2,'valid',activation=tf.nn.relu)(model2)
model2 = tf.keras.layers.Flatten()(model2)
model2 = tf.keras.layers.Dense(200,activation=tf.nn.relu)(model2)
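Note that a chain like this only produces an output tensor, not a model. A minimal sketch of the remaining step, assuming a 64x64 RGB input (the shape and the variable names `inputs`, `x` and `outputs` are illustrative, not from the question): keep a reference to the input tensor and wrap both ends in a `tf.keras.Model`.

```python
import tensorflow as tf

# Assumed example input: 64x64 images with 3 channels.
inputs = tf.keras.layers.Input(shape=(64, 64, 3))

# Each layer is called on the tensor returned by the previous one.
x = tf.keras.layers.Conv2D(10, 3, 2, 'valid', activation=tf.nn.relu)(inputs)
x = tf.keras.layers.Conv2D(20, 3, 2, 'valid', activation=tf.nn.relu)(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(200, activation=tf.nn.relu)(x)

# The Model needs both the input tensor and the output tensor,
# so the input must not be overwritten along the way.
model2 = tf.keras.Model(inputs=inputs, outputs=outputs)
print(model2.output_shape)  # (None, 200)
```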

Update

If you want to create two separate models and join them together, you should use the Functional API, and because of its constraints you must then use input layers. So you could do something like:

import tensorflow as tf
from tensorflow.keras.layers import Input, Flatten, Dense, concatenate, Conv2D
from tensorflow.keras.models import Model

input1 = Input((10,3,2))
model1 = Dense(200,activation=tf.nn.relu)(input1)

input2 = Input((10,3,2))
model2 = Dense(200,activation=tf.nn.relu)(input2)

merged = concatenate([model1, model2])

merged = Conv2D(10,3,2,'valid',activation=tf.nn.relu)(merged)
merged = Flatten()(merged)
merged = Dense(200,activation=tf.nn.relu)(merged)

model = Model(inputs=[input1, input2], outputs=merged)

Above we have two separate inputs, each followed by its own Dense layer. You can build these separate branches however you want. To merge them so they can pass through a convolutional layer together, use tf.keras.layers.concatenate, and then continue the joint model from there. Wrapping the whole thing in a Model object then gives you access to training and inference methods such as fit and predict.
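If the goal is shared weights rather than two independent branches, one possible approach is to keep the Sequential block from the question (which needs no input layer) and call it like a layer on several input tensors. This is a sketch under assumptions: the 64x64 RGB input shape and the names `shared`, `input_a`, `input_b` and `siamese` are made up for illustration.

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, concatenate

# The block from the question, defined without any input layer.
shared = tf.keras.models.Sequential([
    Conv2D(10, 3, 2, 'valid', activation=tf.nn.relu),
    Conv2D(20, 3, 2, 'valid', activation=tf.nn.relu),
    Flatten(),
    Dense(200, activation=tf.nn.relu),
])

# Two inputs with an assumed example shape.
input_a = Input(shape=(64, 64, 3))
input_b = Input(shape=(64, 64, 3))

# Calling the same object twice reuses the same weights for both branches.
features_a = shared(input_a)
features_b = shared(input_b)

merged = concatenate([features_a, features_b])
siamese = tf.keras.Model(inputs=[input_a, input_b], outputs=merged)

# Only one set of conv/dense weights exists: 3 layers x (kernel, bias).
print(len(siamese.trainable_weights))
```

The input layers still exist, but they belong to the outer functional model; the shared block itself stays free of any predefined input.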

Linking in Keras works by propagating tensors through the layers. In your second example, model2 is initially an instance of keras.layers.Layer and not a tf.Tensor, which is why you get the error.

Input creates a tensor which can then be used to link the layers. So unless you have a specific reason not to, just add one:

model2 = tf.keras.layers.Input((10,3,2))
model2 = tf.keras.layers.Conv2D(10,3,2,'valid',activation=tf.nn.relu)(model2)
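The layer-versus-tensor distinction can be checked directly. In this small sketch (the 64x64x3 input shape is an assumed example), the Conv2D object is a Layer, while the value returned by Input is a symbolic tensor that a layer can be called on.

```python
import tensorflow as tf

layer = tf.keras.layers.Conv2D(10, 3, 2, 'valid', activation=tf.nn.relu)
tensor = tf.keras.layers.Input(shape=(64, 64, 3))  # assumed example shape

# A Conv2D instance is a Layer; it is callable but carries no data or shape.
is_layer = isinstance(layer, tf.keras.layers.Layer)

# Input() does not return a Layer but a symbolic tensor with a known shape.
input_is_layer = isinstance(tensor, tf.keras.layers.Layer)

# Calling a layer on a tensor returns another tensor that the next layer can consume.
out = layer(tensor)
print(is_layer, input_is_layer, out.shape)
```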
