[英]Why is the TensorFlow Flatten layer not changing the input shape?
我是神经网络编程和使用 tensorflow 的新手,并且在最后一天尝试构建自己的一些简单网络以进行练习。 我有一个 (113, 200) 的形状,我试图在顺序网络的第一层中展平,但是当密集层运行时我收到错误
ValueError: Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 22600 but received input with shape [113, 200]
我还注意到我收到了警告
tensorflow:Model was constructed with shape (None, 113, 200) for input Tensor("flatten_input:0", shape=(None, 113, 200), dtype=float32), but it was called on an input with incompatible shape (113, 200).
但是当我将 input_shape 更改为 (None, 113, 200) 我收到其他错误
以下是我当前的代码
model_x = keras.Sequential([
keras.layers.Flatten(input_shape=(113, 200)),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(2)
])
model_x.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer="adam",
metrics=["accuracy"])
cp_path_x = "TestAITraining/cpx.ckpt"
cp_dir_x = os.path.dirname(cp_path_x)
cp_callback_x = keras.callbacks.ModelCheckpoint(filepath=cp_path_x, save_weights_only=True, verbose=1)
#model.load_weights(cp_path)
while True:
model_x.fit(tf.data.Dataset.from_tensor_slices((train_arrays, train_pos)), epochs=20)
train_array 在被传递到数据集之前是一个 (1000, 113, 200) python 列表,而 train_pos 在被添加到数据集之前是一个 (1000, 2) python 列表
import tensorflow as tf
model_x = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(113, 200)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(2, activation='softmax')
])
model_x.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer="adam",
metrics=["accuracy"])
model_x.summary()
Output
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_1 (Flatten) (None, 22600) 0
dense_3 (Dense) (None, 64) 1446464
dense_4 (Dense) (None, 64) 4160
dense_5 (Dense) (None, 2) 130
=================================================================
Total params: 1,450,754
Trainable params: 1,450,754
Non-trainable params: 0
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.