繁体   English   中英

我如何为我的 NN 模型正确地塑造我的数据?

[英]How Do I Correctly Shape my Data for my NN Model?

我正在尝试在 Keras 中创建一个基本的神经网络模型,该模型将学习将正数加在一起,但在塑造训练数据以适应模型时遇到了麻烦:

我已经为第一个 Dense 层的“input_shape”属性尝试了多种配置,但似乎没有任何效果。

# training data
training = np.array([[2,3],[4,5],[3,8],[2,9],[11,4],[13,5],[2,9]], dtype=float)
answers  = np.array([5, 9, 11, 11, 15, 18, 11],  dtype=float)
# create our model now:
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=16, input_shape=(2,)),
    tf.keras.layers.Dense(units=16, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=1)
])
# set up the compile parameters: 
model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.Adam(.1))
#fit the model:
model.fit(training, answers, epochs=550, verbose=False)

print(model.predict([[7,9]]))

我希望它可以正常运行并产生结果“16”,但出现以下错误:

"Traceback (most recent call last):
  File "c2f", line 27, in <module>
    print(model.predict([[7,9]]))
  File "C:\Users\Aalok\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1096, in predict
    x, check_steps=True, steps_name='steps', steps=steps)
  File "C:\Users\Aalok\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2382, in _standardize_user_data
    exception_prefix='input')
  File "C:\Users\Aalok\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 362, in standardize_input_data
    ' but got array with shape ' + str(data_shape))

ValueError:检查输入时出错:预期dense_input具有形状(2,)但得到形状为(1,)的数组

得到错误。 根据堆栈跟踪,错误在这一行,

print(model.predict([[7,9]]))

现在,Keras Sequential 模型需要 NumPy 数组 ( ndarray ) 形式的输入。 在上面的行中,模型将数组解释为多个输入的列表(这不是您的情况)。

根据官方文档model.fit()参数x是,

训练数据的 Numpy 数组(如果模型有单个输入),或 Numpy 数组列表(如果模型有多个输入)。 如果模型中的输入层已命名,您还可以将字典映射输入名称传递给 Numpy 数组。 如果从框架原生张量(例如 TensorFlow 数据张量)馈送,则 x 可以为 None (默认)。

我们需要创建一个 NumPy 数组,使用numpy.array()

print(model.predict(np.array([[7,9]])))

错误已解决。 我已经运行了代码,它工作得很好。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM