簡體   English   中英

在 Tensorflow 中實現一個扁平層

[英]Implement a flatten layer in Tensorflow

我正在嘗試使用 TensorFlow 2.2.0 實現一個扁平層。 我正在按照 Geron 的書(第 2 版)中的說明進行操作。 至於扁平層,我首先嘗試獲取批量輸入形狀並計算新形狀。 但是我在張量維度上遇到了這個問題: TypeError: Dimension value must be integer or None or have an __index__ method

import tensorflow as tf
from tensorflow import keras
(X_train, y_train), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
input_shape = X_train.shape[1:]
assert input_shape == (28, 28)

class MyFlatten(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        super().build(batch_input_shape) 

    def call(self, X):
        X_shape = tf.shape(X)
        batch_size = X_shape[0]
        new_shape = tf.TensorShape([batch_size, X_shape[1]*X_shape[2]])
        return tf.reshape(X, new_shape)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

## works fine on this example
MyFlatten()(X_train[:10])

## fail when building a model
input_ = keras.layers.Input(shape=[28, 28])
fltten_ = MyFlatten()(input_)
hidden1 = keras.layers.Dense(300, activation="relu")(fltten_)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
output = keras.layers.Dense(10, activation="softmax")(hidden2)
model = keras.models.Model(inputs=[input_], outputs=[output])
model.summary()

不要嘗試創建tf.TensorShape ,它僅在已知張量的所有維度時才有效,實際上僅在急切模式下,因此 model 編譯將失敗。 像這樣簡單地重塑:

def call(self, X):
    X_shape = tf.shape(X)
    batch_size = X_shape[0]
    new_shape = [batch_size, X_shape[1] * X_shape[2]]
    return tf.reshape(X, new_shape)

或者,更一般地說,您可以這樣做:

def call(self, X):
    X_shape = tf.shape(X)
    batch_size = X_shape[0]
    new_shape = [batch_size, tf.math.reduce_prod(X_shape[1:])]
    return tf.reshape(X, new_shape)

tf.reshape也會接受類似new_shape = [batch_size, -1]的東西,但我認為這可能會使展平尺寸的大小未知,具體取決於具體情況。 另一方面,相反的東西new_shape = [-1, tf.math.reduce_prod(X_shape[1:])]也應該可以正常工作。

順便說一句,我假設您這樣做是為了練習並且已經知道這一點,但僅供參考,Keras 中已經有一個Flatten層(您可以查看它的源代碼)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM