簡體   English   中英

來自 keras 密集層的意外 output 形狀

[英]Unexpected output shape from a keras dense layer

我嘗試創建一個僅具有一個隱藏層的最小非卷積NN 圖像二進制分類器(作為更復雜模型之前的實踐):

def make_model(input_shape):
    inputs = keras.Input(shape=input_shape)
    x = layers.Dense(128, activation="ReLU")(inputs)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)
model = make_model(input_shape=(256, 256, 3))

它的model.summary()顯示

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 256, 256, 3)]     0                                                                       
 dense (Dense)               (None, 256, 256, 128)     512                                                                    
 dense_1 (Dense)             (None, 256, 256, 1)       129                                                      
=================================================================
Total params: 641
Trainable params: 641
Non-trainable params: 0

由於dense_1層只有一個神經元,我對這一層的期望是(None, 1)的output形狀(即,表示預測的二進制標簽的單個數字),而是model給出(None, 256, 256, 1) 256,256 (None, 256, 256, 1) .

我的 model 設置有什么問題,如何才能正確設置?

如果你想使用 output 形狀(None, 1)你必須壓平你荒謬的大張量:

import tensorflow as tf

def make_model(input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Dense(128, activation="relu")(inputs)
    x = tf.keras.layers.Flatten()(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs)

model = make_model(input_shape=(256, 256, 3))
print(model.summary())

您的 function make_model中有一個錯誤。

def make_model(input_shape):
    inputs = keras.Input(shape=input_shape)
    x = layers.Dense(128, activation="ReLU")(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)

您可能希望第二行是

x = layers.Dense(128, activation="ReLU")(inputs)

並不是

x = layers.Dense(128, activation="ReLU")(x)

不幸的是, x存在於 scope 中,所以它沒有拋出錯誤。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM