[英]TensorFlow2 / Keras: input_shape seems to not have an effect when subclassing keras.Model
訪問中間層的輸出時,我總是收到錯誤消息: AttributeError: Layer l has no inbound nodes.
我讀到input_shape
在入口層確定input_shape
才能克服這個問題:
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
self.dense = tf.keras.layers.Dense(64, activation=tf.nn.relu)
self.classifier = tf.keras.layers.Dense(10, name='classifier')
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense(x)
return self.classifier(x)
不幸的是,這對我不起作用。
在這次失敗的嘗試之后,我嘗試使用tf.keras.Sequential
重建我的模型,並再次為第一層指定input_shape
:
tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(64, activation=tf.nn.relu),
tf.keras.layers.Dense(10, name='classifier')
])
另一方面,這有效!
所以我問自己為什么第一種方法不起作用。 為了測試這一點,我為子類模型指定了任意/錯誤的input_shapes
,如下所示:
class WrongInputShapeModel(tf.keras.Model):
def __init__(self):
super(WrongInputShapeModel, self).__init__()
self.flatten = tf.keras.layers.Flatten(input_shape=(42, 42, 42))
# ...
我注意到這個模型仍然可以與 MNIST(具有 28x28 圖像)一起使用。 這讓我相信關鍵字input_shape
在定義模型子類化tf.keras.Model
。
這是一個錯誤,還是我錯過了什么?
從您的模型子類化代碼來看,您似乎沒有返回模型的前向傳遞輸出。 也許這就是它不起作用的原因。
這應該有效:
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.flatten = tf.keras.layers.Flatten()
self.dense = tf.keras.layers.Dense(64, activation=tf.nn.relu)
self.classifier = tf.keras.layers.Dense(10, name='classifier', activation='softmax')
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense(x)
return self.classifier(x)
然后,舉個例子:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
model = MyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
history = model.fit(x_train, y_train, epochs=5)
現在, Flatten
輸入形狀實際上是由層的輸入張量的形狀定義的,因此它不需要input_shape
參數。 input_shape
是Layer
一個屬性,所以你可以傳遞它,但是 layer 的 forward-pass 不使用它。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.