簡體   English   中英

我如何為我的 NN 模型正確地塑造我的數據?

[英]How Do I Correctly Shape my Data for my NN Model?

我正在嘗試在 Keras 中創建一個基本的神經網絡模型,該模型將學習將正數加在一起,但在塑造訓練數據以適應模型時遇到了麻煩:

我已經為第一個 Dense 層的“input_shape”屬性嘗試了多種配置,但似乎沒有任何效果。

# training data
training = np.array([[2,3],[4,5],[3,8],[2,9],[11,4],[13,5],[2,9]], dtype=float)
answers  = np.array([5, 9, 11, 11, 15, 18, 11],  dtype=float)
# create our model now:
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=16, input_shape=(2,)),
    tf.keras.layers.Dense(units=16, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=1)
])
# set up the compile parameters: 
model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.Adam(.1))
#fit the model:
model.fit(training, answers, epochs=550, verbose=False)

print(model.predict([[7,9]]))

我希望它可以正常運行並產生結果“16”,但出現以下錯誤:

"Traceback (most recent call last):
  File "c2f", line 27, in <module>
    print(model.predict([[7,9]]))
  File "C:\Users\Aalok\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1096, in predict
    x, check_steps=True, steps_name='steps', steps=steps)
  File "C:\Users\Aalok\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2382, in _standardize_user_data
    exception_prefix='input')
  File "C:\Users\Aalok\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 362, in standardize_input_data
    ' but got array with shape ' + str(data_shape))

ValueError:檢查輸入時出錯:預期dense_input具有形狀(2,)但得到形狀為(1,)的數組

得到錯誤。 根據堆棧跟蹤,錯誤在這一行,

print(model.predict([[7,9]]))

現在,Keras Sequential 模型需要 NumPy 數組 ( ndarray ) 形式的輸入。 在上面的行中,模型將數組解釋為多個輸入的列表(這不是您的情況)。

根據官方文檔model.fit()參數x是,

訓練數據的 Numpy 數組(如果模型有單個輸入),或 Numpy 數組列表(如果模型有多個輸入)。 如果模型中的輸入層已命名,您還可以將字典映射輸入名稱傳遞給 Numpy 數組。 如果從框架原生張量(例如 TensorFlow 數據張量)饋送,則 x 可以為 None (默認)。

我們需要創建一個 NumPy 數組,使用numpy.array()

print(model.predict(np.array([[7,9]])))

錯誤已解決。 我已經運行了代碼,它工作得很好。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM