
Determine the input shape of a Keras model

I have a question about feature_columns and the input_shape argument of Keras layers, such as tf.keras.layers.InputLayer, in TensorFlow.

I'm following an example which has the following code to create the feature columns:

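import tensorflow as tf
from tensorflow.keras import layers

# One numeric feature column per input feature.
feature_columns = []
latitude = tf.feature_column.numeric_column("latitude")
feature_columns.append(latitude)
longitude = tf.feature_column.numeric_column("longitude")
feature_columns.append(longitude)
# Combine the feature columns into a single input layer.
fp_feature_layer = layers.DenseFeatures(feature_columns)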

And below is the code to build the model:

def build_model(my_learning_rate, feature_layer):
    model = tf.keras.models.Sequential()
    # The feature layer converts the dict of input features into a dense tensor.
    model.add(feature_layer)
    model.add(tf.keras.layers.Dense(units=1, input_shape=(1,)))
    model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=my_learning_rate),
                  loss="mean_squared_error",
                  metrics=[tf.keras.metrics.RootMeanSquaredError()])
    return model

When calling build_model, I pass in a learning rate and the feature layer, which is fp_feature_layer. Since feature_columns contains two features, latitude and longitude, shouldn't input_shape be (2,) instead of (1,)? More generally, since the code already specifies feature_layer, do we still need to specify input_shape in model.add(tf.keras.layers.Dense(...))? Shouldn't the input shape be determined by feature_layer? Is that how it works? Because the output is a single value per example, units=1 makes sense to me, but I'm having a hard time understanding input_shape. Thanks in advance!

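Yes, your understanding is correct: the feature layer determines the input shape, so we don't need to specify it again in the first hidden layer. Keras infers the Dense layer's input size from the output of the feature layer, which here has one value per feature column, i.e. 2 values per example.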
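To make that shape concrete, here is a minimal pure-Python sketch of what the feature layer produces, assuming two scalar numeric_column features and a batch given as a dict of lists. The function dense_features_sketch and the sample coordinates are hypothetical and only illustrate the shape; the real layers.DenseFeatures returns a TensorFlow tensor, not a list of lists.

```python
# Hypothetical stand-in for DenseFeatures with scalar numeric_column inputs:
# each named feature contributes one float per example, and the per-example
# values are concatenated into a single row.
def dense_features_sketch(features, column_names):
    """Return one row per example, i.e. shape (batch, len(column_names))."""
    batch_size = len(features[column_names[0]])
    rows = []
    for i in range(batch_size):
        rows.append([float(features[name][i]) for name in column_names])
    return rows


# Hypothetical sample batch of three examples.
features = {
    "latitude": [34.05, 37.77, 32.72],
    "longitude": [-118.24, -122.42, -117.16],
}
rows = dense_features_sketch(features, ["latitude", "longitude"])
print(len(rows), len(rows[0]))  # 3 2
```

Each row has two values, one per feature column, which is why a following Dense layer sees an input width of 2 rather than 1.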

So, the code can be modified from

model.add(tf.keras.layers.Dense(units=1, input_shape=(1,)))

to

model.add(tf.keras.layers.Dense(units=1))
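Dropping input_shape works because a Keras Dense layer creates its kernel lazily in build(), using the last dimension of the input it receives. The class below is a hypothetical, simplified pure-Python stand-in for that behavior (no activation, plain lists instead of tensors, made-up sample rows); it is a sketch of the idea, not the Keras implementation.

```python
import random


class LazyDenseSketch:
    """Hypothetical, simplified Dense layer whose weights are created on first call."""

    def __init__(self, units):
        self.units = units
        self.kernel = None  # shape (input_dim, units), unknown until build()
        self.bias = None

    def build(self, input_dim):
        # The kernel's first dimension comes from the incoming data,
        # not from a user-supplied input_shape argument.
        self.kernel = [[random.uniform(-0.05, 0.05) for _ in range(self.units)]
                       for _ in range(input_dim)]
        self.bias = [0.0] * self.units

    def __call__(self, rows):
        # Build on the first call from the width of the incoming rows.
        if self.kernel is None:
            self.build(len(rows[0]))
        outputs = []
        for row in rows:
            outputs.append([
                sum(x * self.kernel[j][u] for j, x in enumerate(row)) + self.bias[u]
                for u in range(self.units)
            ])
        return outputs


layer = LazyDenseSketch(units=1)
out = layer([[34.05, -118.24], [37.77, -122.42]])
print(len(layer.kernel), len(layer.kernel[0]))  # 2 1
print(len(out), len(out[0]))  # 2 1
```

With two features per example, the kernel ends up with shape (2, 1) even though the layer was only told units=1.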

Please refer to this Comprehensive TensorFlow Tutorial to understand how to use Feature Columns and a Feature Layer in a Keras Sequential model.

Hope this helps. Happy Learning!
