簡體   English   中英

Keras Conv1d 輸入形狀:檢查輸入時出錯

[英]Keras Conv1d Input Shape: Error when checking input

我正在使用帶有 TF 后端的 keras 來構建一個簡單的Conv1d網絡。 數據具有以下形狀:

train feature shape: (33960, 3053, 1)
train label shape: (33960, 686, 1)

我用以下方法構建模型:

def create_conv_model():

    inp =  Input(shape=(3053, 1))
    conv = Conv1D(filters=2, kernel_size=2)(inp)
    pool = MaxPool1D(pool_size=2)(conv)
    flat = Flatten()(pool)
    dense = Dense(686)(flat)
    model = Model(inp, dense)
    model.compile(loss='mse', optimizer='adam')

    return model

型號概要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 3053, 1)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 3052, 2)           6         
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 1526, 2)           0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 3052)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 686)               2094358   
=================================================================
Total params: 2,094,364
Trainable params: 2,094,364
Non-trainable params: 0

運行時

model.fit(x=train_feature,
    y=train_label_categorical,
    epochs=100,
    batch_size=64,
    validation_split=0.2,
    validation_data=(test_feature,test_label_categorical),
    callbacks=[tensorboard,reduce_lr,early_stopping])

我收到以下非常常見的錯誤:

ValueError: Error when checking input: expected input_1 to have 3 dimensions, but got array with shape (8491, 3053)

我已經檢查了幾乎所有關於這個非常常見問題的帖子,但我一直無法找到解決方案。 我究竟做錯了什么? 我不明白發生了什么。 形狀(8491, 3053)來自哪里?

任何幫助將不勝感激,我無法讓它消失。

model.fit函數中的validation_data=(test_feature,test_label_categorical) model.fit

validation_data=(np.expand_dims(test_feature, -1),test_label_categorical)

該模型需要形狀(8491, 3053, 1)驗證功能,但在上面的代碼中,您提供的是(8491, 3053)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM