簡體   English   中英

Python:無法在Keras中訓練回歸模型

[英]Python: cannot train regression model in Keras

我正在嘗試與Keras一起訓練DNN。 在此定義模型:

model = Sequential()
model.add(Dense(2050, input_shape=(2050, 75), activation='relu'))
model.add(Dense(1024, activation='relu'))
model.add(Dense(1024, activation='relu'))
model.add(Dense(1024, activation='relu'))
model.add(Dense(75, activation='sigmoid'))

成本函數是毫秒。 這里的想法是訓練一組3000張尺寸為2050 * 75的圖像,這基本上是從1025 * 75圖像中提取的兩個不同特征,以便在輸出中獲得3000張尺寸為1025 * 75的圖像,這是某種原始圖像的表示形式。

因此,輸入為(3000,2050,75)張量,而輸出尺寸為(3000,1025,75)。

我可以看到Keras為什么給我以下錯誤:

ValueError:檢查目標時出錯:預期density_5具有形狀(None,2050,75),但數組的形狀為(3000,1025,75)

必須有一種避免此錯誤的方法,可能是通過重新定義DNN尺寸或圖層。 你有什么建議嗎? 謝謝。

編輯:根據要求,這是完整的代碼。

X = train_set
Y = m
[n_samples, n_freq, n_time] = X.shape

model = Sequential()
model.add(Dense(n_freq, input_shape=(n_freq, n_time), activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_time, activation='sigmoid'))

model.summary()
model.compile(optimizer='rmsprop',loss='mse',metrics=['mae','accuracy'])
model.fit(np.abs(X), np.abs(Y), epochs=n_epochs, batch_size=batch_size)
score = model.evaluate(np.abs(X), np.abs(Y), batch_size = batch_size)

因為您不能在內部使用重塑層重塑陣列,因為新陣列的總大小必須保持不變。 我建議使用平整層平整張量。 但首先,您需要重塑y:

y = y.reshape(-1, 1025*75)

更新后的模型如下所示:

model = Sequential()
model.add(Dense(n_freq, input_shape=(n_freq, n_time), activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Dense(n_hid, activation='relu'))
model.add(Flatten())
model.add(Dense(1025*75, activation='sigmoid'))

model.summary()
model.compile(optimizer='rmsprop',loss='mse',metrics=['mae','accuracy'])
model.fit(np.abs(X), np.abs(y), epochs=n_epochs, batch_size=batch_size)
#score = model.evaluate(np.abs(X), np.abs(Y), batch_size = batch_size)

之后,您可以將y_pred重塑為(-1,1015,75)形狀:

y_pred = y_pred.reshape(-1,1015,75)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM