
Implement a flatten layer in Tensorflow

I am trying to implement a flatten layer using TensorFlow 2.2.0, following the instructions in Geron's book (2nd ed.). For the flatten layer, I first try to get the batch input shape and then compute the new shape. However, I ran into this error related to tensor dimensions:

TypeError: Dimension value must be integer or None or have an __index__ method

import tensorflow as tf
from tensorflow import keras
(X_train, y_train), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
input_shape = X_train.shape[1:]
assert input_shape == (28, 28)

class MyFlatten(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        super().build(batch_input_shape) 

    def call(self, X):
        X_shape = tf.shape(X)
        batch_size = X_shape[0]
        new_shape = tf.TensorShape([batch_size, X_shape[1]*X_shape[2]])
        return tf.reshape(X, new_shape)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

## works fine on this example
MyFlatten()(X_train[:10])

## fails when building a model
input_ = keras.layers.Input(shape=[28, 28])
fltten_ = MyFlatten()(input_)
hidden1 = keras.layers.Dense(300, activation="relu")(fltten_)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
output = keras.layers.Dense(10, activation="softmax")(hidden2)
model = keras.models.Model(inputs=[input_], outputs=[output])
model.summary()
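A minimal sketch that reproduces the symptom outside of Keras, assuming TF 2.x with eager execution on by default. The same tf.TensorShape call succeeds when the batch size comes from an eager tensor, and is expected to raise the TypeError when the batch size is a symbolic graph tensor, which is what the layer receives while the functional model is being built. The helper try_static_shape is hypothetical and only used for illustration.

```python
import tensorflow as tf


def try_static_shape(batch_size):
    """Hypothetical helper: build a TensorShape, or return None if it is rejected."""
    try:
        return tf.TensorShape([batch_size, 784])
    except TypeError:
        # Raised when batch_size has no usable __index__ (e.g. a symbolic graph tensor)
        return None


# Eager mode: tf.shape() returns a concrete EagerTensor, so TensorShape accepts it.
eager_x = tf.zeros([10, 28, 28])
eager_result = try_static_shape(tf.shape(eager_x)[0])

# Graph mode, similar to what happens when Keras builds a functional model:
# the batch size is a symbolic tensor whose value is unknown until run time.
graph = tf.Graph()
with graph.as_default():
    graph_x = tf.compat.v1.placeholder(tf.float32, shape=[None, 28, 28])
    graph_result = try_static_shape(tf.shape(graph_x)[0])

print("eager:", eager_result)
print("graph:", graph_result)
```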

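Don't try to create a tf.TensorShape: that only works when all dimensions of the tensor are known, which in practice happens only in eager mode, so building the model fails. When Keras builds the model, the layer is called on a symbolic input, so tf.shape(X) returns a symbolic tensor with no concrete values, and tf.TensorShape rejects it with the TypeError above. tf.reshape, on the other hand, accepts the new shape as a tensor or as a list containing tensors, so simply reshape like this: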

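def call(self, X):
    X_shape = tf.shape(X)  # dynamic shape: a 1-D int32 tensor
    batch_size = X_shape[0]
    # A plain Python list is fine here: tf.reshape packs its elements into a shape tensor
    new_shape = [batch_size, X_shape[1] * X_shape[2]]
    return tf.reshape(X, new_shape)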

Or, more generally, you can just do:

def call(self, X):
    X_shape = tf.shape(X)
    batch_size = X_shape[0]
    new_shape = [batch_size, tf.math.reduce_prod(X_shape[1:])]
    return tf.reshape(X, new_shape)
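As a quick sanity check (a sketch, assuming TF 2.x in eager mode), the more general call can be exercised directly on concrete tensors of different ranks. The class below keeps only the call method from the answer, and the test inputs are made up for illustration.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class MyFlatten(keras.layers.Layer):
    def call(self, X):
        X_shape = tf.shape(X)
        batch_size = X_shape[0]
        # Product of all non-batch dimensions, so any input rank >= 2 works
        new_shape = [batch_size, tf.math.reduce_prod(X_shape[1:])]
        return tf.reshape(X, new_shape)


# A 3-D batch shaped like Fashion MNIST images, filled with distinct values
images = tf.reshape(tf.range(10 * 28 * 28, dtype=tf.float32), [10, 28, 28])
# A 4-D batch, to show the general version is not limited to 2 non-batch dims
volumes = tf.zeros([2, 3, 4, 5])

flat_images = MyFlatten()(images)
flat_volumes = MyFlatten()(volumes)

# tf.reshape keeps row-major order, so the result should match NumPy's reshape
same_as_numpy = np.array_equal(flat_images.numpy(), images.numpy().reshape(10, -1))

print(flat_images.shape, flat_volumes.shape, same_as_numpy)
```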

tf.reshape would also accept something like new_shape = [batch_size, -1], but I think that may leave the size of the flattened dimension unknown, depending on the case. Conversely, new_shape = [-1, tf.math.reduce_prod(X_shape[1:])] should work fine too.
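If you would rather not depend on how TensorFlow infers static shapes from a dynamic shape tensor, one alternative (not part of the original answer) is to compute the flattened size once in build(), from the static batch_input_shape, and pass a constant shape with -1 only in the batch position. This is a sketch: StaticFlatten is a hypothetical name, and it assumes all non-batch dimensions are known when the layer is built.

```python
import tensorflow as tf
from tensorflow import keras


class StaticFlatten(keras.layers.Layer):
    """Hypothetical variant that computes the flattened size from the static shape."""

    def build(self, batch_input_shape):
        # e.g. (None, 28, 28) -> 28 * 28 = 784; needs every non-batch dim to be known
        self.flat_dim = tf.TensorShape(batch_input_shape)[1:].num_elements()
        if self.flat_dim is None:
            raise ValueError("StaticFlatten needs fully known non-batch dimensions")
        super().build(batch_input_shape)

    def call(self, X):
        # -1 only in the batch position; the constant flat_dim keeps the static shape known
        return tf.reshape(X, [-1, self.flat_dim])


input_ = keras.layers.Input(shape=[28, 28])
flat = StaticFlatten()(input_)
hidden = keras.layers.Dense(300, activation="relu")(flat)
output = keras.layers.Dense(10, activation="softmax")(hidden)
model = keras.models.Model(inputs=[input_], outputs=[output])

# Run a dummy batch through the model to check the end-to-end shapes
predictions = model(tf.zeros([4, 28, 28]))
print(flat.shape, predictions.shape)
```

The trade-off is that this layer is tied to the input shape it was built with, whereas the tf.shape-based versions adapt at call time.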

By the way, I assume you are doing this as an exercise and already know this, but just for reference, Keras already has a Flatten layer (and you can check its source code).
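For comparison, a sketch of the question's model with the built-in keras.layers.Flatten swapped in for MyFlatten. It uses only the layers already present in the question, and the dummy batch is made up for illustration.

```python
import tensorflow as tf
from tensorflow import keras

input_ = keras.layers.Input(shape=[28, 28])
flat = keras.layers.Flatten()(input_)  # built-in replacement for MyFlatten
hidden1 = keras.layers.Dense(300, activation="relu")(flat)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
output = keras.layers.Dense(10, activation="softmax")(hidden2)
model = keras.models.Model(inputs=[input_], outputs=[output])
model.summary()

# Run a dummy batch of 3 images to check the output shape
probs = model(tf.zeros([3, 28, 28]))
```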

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM