簡體   English   中英

TensorFlow2 / Keras:當子類化 keras.Model 時 input_shape 似乎沒有效果

[英]TensorFlow2 / Keras: input_shape seems to not have an effect when subclassing keras.Model

訪問中間層的輸出時,我總是收到錯誤消息: AttributeError: Layer l has no inbound nodes. 我讀到input_shape在入口層確定input_shape才能克服這個問題:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense = tf.keras.layers.Dense(64, activation=tf.nn.relu)
        self.classifier = tf.keras.layers.Dense(10, name='classifier')

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense(x)
        return self.classifier(x)

不幸的是,這對我不起作用。

在這次失敗的嘗試之后,我嘗試使用tf.keras.Sequential重建我的模型,並再次為第一層指定input_shape

tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, name='classifier')
])

另一方面,這有效!

所以我問自己為什么第一種方法不起作用。 為了測試這一點,我為子類模型指定了任意/錯誤的input_shapes ,如下所示:

class WrongInputShapeModel(tf.keras.Model):
    def __init__(self):
        super(WrongInputShapeModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten(input_shape=(42, 42, 42))
        # ...

我注意到這個模型仍然可以與 MNIST(具有 28x28 圖像)一起使用。 這讓我相信關鍵字input_shape在定義模型子類化tf.keras.Model

這是一個錯誤,還是我錯過了什么?

從您的模型子類化代碼來看,您似乎沒有返回模型的前向傳遞輸出。 也許這就是它不起作用的原因。

這應該有效:

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(64, activation=tf.nn.relu)
        self.classifier = tf.keras.layers.Dense(10, name='classifier', activation='softmax')

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense(x)
        return self.classifier(x)

然后,舉個例子:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

model = MyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=tf.keras.losses.SparseCategoricalCrossentropy(),
             metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
history = model.fit(x_train, y_train, epochs=5)

現在, Flatten輸入形狀實際上是由層的輸入張量的形狀定義的,因此它不需要input_shape參數。 input_shapeLayer一個屬性,所以你可以傳遞它,但是 layer 的 forward-pass 不使用它。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM