繁体   English   中英

为什么 TensorFlow Flatten 层不改变输入形状?

[英]Why is the TensorFlow Flatten layer not changing the input shape?

我是神经网络编程和使用 tensorflow 的新手,并且在最后一天尝试构建自己的一些简单网络以进行练习。 我有一个 (113, 200) 的形状,我试图在顺序网络的第一层中展平,但是当密集层运行时我收到错误

ValueError: Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 22600 but received input with shape [113, 200]

我还注意到我收到了警告

tensorflow:Model was constructed with shape (None, 113, 200) for input Tensor("flatten_input:0", shape=(None, 113, 200), dtype=float32), but it was called on an input with incompatible shape (113, 200).

但是当我将 input_shape 更改为 (None, 113, 200) 我收到其他错误

以下是我当前的代码

model_x = keras.Sequential([
    keras.layers.Flatten(input_shape=(113, 200)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(2)
])
model_x.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer="adam",
                metrics=["accuracy"])

cp_path_x = "TestAITraining/cpx.ckpt"
cp_dir_x = os.path.dirname(cp_path_x)
cp_callback_x = keras.callbacks.ModelCheckpoint(filepath=cp_path_x, save_weights_only=True, verbose=1)
#model.load_weights(cp_path)
while True:
    model_x.fit(tf.data.Dataset.from_tensor_slices((train_arrays, train_pos)), epochs=20)

train_array 在被传递到数据集之前是一个 (1000, 113, 200) python 列表,而 train_pos 在被添加到数据集之前是一个 (1000, 2) python 列表

import tensorflow as tf
model_x = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(113, 200)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax')
])
model_x.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer="adam",
                metrics=["accuracy"])
model_x.summary()

Output

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten_1 (Flatten)         (None, 22600)             0         
                                                                 
 dense_3 (Dense)             (None, 64)                1446464   
                                                                 
 dense_4 (Dense)             (None, 64)                4160      
                                                                 
 dense_5 (Dense)             (None, 2)                 130       
                                                                 
=================================================================
Total params: 1,450,754
Trainable params: 1,450,754
Non-trainable params: 0

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM