簡體   English   中英

Keras Train神經網絡維度值錯誤:預期具有2個維度,但數組的形狀為(32,1,4)

[英]Keras Train Neural Network Dimension Value Error: expected to have 2 dimensions, but got array with shape (32, 1, 4)

Python 3.6
Keras 2.2
Tensorflow 1.8 backend

我在訓練我的神經網絡時遇到了麻煩,因為出現了以下錯誤:

ValueError: Error when checking target: expected t_dense_3 to have 2 dimensions, but got array with shape (32, 1, 4)

我的神經網絡

>>> sgd = optimizers.SGD(lr=0.01, decay=1e-6)
>>> target_q_network = Sequential([
      Dense(40, input_shape=observation_shape, activation='relu', name='t_dense_1'),
      Dense(40, activation='relu', name='t_dense_2'),
      Dense(number_of_actions, activation='linear', name='t_dense_3')
    ])
>>> target_q_network.compile(loss='mean_squared_error', optimizer=sgd)
>>> observation_shape
    (8,)

-----------------------------------------------------------------

(Pdb) target_q_network.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
t_dense_1 (Dense)            (None, 40)                360       
_________________________________________________________________
t_dense_2 (Dense)            (None, 40)                1640      
_________________________________________________________________
t_dense_3 (Dense)            (None, 4)                 164       
=================================================================
Total params: 2,164
Trainable params: 2,164
Non-trainable params: 0
_________________________________________________________________

當我將值傳遞給神經網絡時,將返回形狀(1、4)的數組:

(Pdb) env.reset()
array([-0.00126171,  0.94592496, -0.12780861,  0.35410735,  0.00146875, 0.02895054,  0.        ,  0.        ])
# Passing value into Neural Network
(Pdb) target_q_network.predict(env.reset().reshape(1,8))
array([[ 0.07440183,  0.03480911,  0.11266299, -0.08043154]], dtype=float32)

我正在傳遞training_setlabels

(Pdb) training_set.shape
(32, 8)
(Pdb) labels.shape
(32, 1, 4)

'mean_squared_error'損失函數可能期望接收(batch_sz x n_labels)標簽矩陣,但是您要傳遞(batch_sz x 1 x n_labels)標簽矩陣,尤其是使用labels.shape=(32, 1, 4) 您只需要調整labels的形狀以使其具有形狀(batch_sz x n_labels) ,使其具有labels.shape=(32, 4) ,然后可以將其與神經網絡輸出進行適當比較。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM