Setting the shape of tensorflow sequential model input layer
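I am trying to build a model for multi-class classification, but I don't understand how to set the correct input shape. I have a training set with shape (5420, 212), and this is the model I built: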
```python
from tensorflow.keras import models, layers

model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape = (5420,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(5, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
history = model.fit(X_train, y_train, epochs=20, batch_size=512)
```
When I run it, I get the error:

```
ValueError: Input 0 of layer sequential_9 is incompatible with the layer: expected axis -1 of input shape to have value 5420 but received input with shape (None, 212)
```

Why? Is the input shape wrong?
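The input shape should equal the second dimension of the input `X`, and the output shape (the number of units in the last layer) should equal the second dimension of the output `Y`, assuming `X` and `Y` are both two-dimensional (i.e., they have no higher dimensions). In your case `input_shape` must be `(212,)`, the number of features, not `(5420,)`, the number of samples: Keras adds the batch dimension itself, which is why the error reports the received shape as `(None, 212)`.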
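A minimal sketch of the mismatch, assuming TensorFlow 2.x with its bundled Keras and a random toy batch (the array `x` and the helper `build` are hypothetical, not from the question): the same architecture built with `input_shape=(5420,)` rejects a batch with 212 features, while `input_shape=x.shape[1:]` accepts it and reports a leading `None` batch axis.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy batch shaped like the question's data: 8 samples, 212 features.
x = np.random.rand(8, 212).astype("float32")

def build(input_shape):
    """Hypothetical helper: the question's architecture, only input_shape varies."""
    return tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=input_shape),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(5, activation="softmax"),
    ])

# Wrong: (5420,) is the number of samples, not the number of features per sample.
wrong = build((5420,))
try:
    wrong(x)
    wrong_failed = False
except ValueError as err:
    wrong_failed = True
    print("ValueError:", err)

# Right: pass only the feature dimension; Keras prepends the batch axis as None.
right = build(x.shape[1:])
print(right.input_shape)  # (None, 212)
print(right(x).shape)     # (8, 5)
```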
```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.datasets import make_classification
from sklearn.preprocessing import OneHotEncoder
tf.random.set_seed(0)
# generate some data
X, y = make_classification(n_classes=5, n_samples=5420, n_features=212, n_informative=212, n_redundant=0, random_state=42)
print(X.shape, y.shape)
# (5420, 212) (5420,)
# one-hot encode the target (scikit-learn >= 1.2; on older versions use sparse=False)
Y = OneHotEncoder(sparse_output=False).fit_transform(y.reshape(-1, 1))
print(X.shape, Y.shape)
# (5420, 212) (5420, 5)
# extract the input and output shapes
input_shape = X.shape[1:] # the input shape is X's second dimension
output_shape = Y.shape[1] # the output shape is Y's second dimension
print(input_shape, output_shape)
# (212,) 5
# define the model
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=input_shape))
model.add(Dense(64, activation='relu'))
model.add(Dense(output_shape, activation='softmax'))
# compile the model
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
# fit the model
history = model.fit(X, Y, epochs=3, batch_size=512)
# Epoch 1/3
# 11/11 [==============================] - 0s 1ms/step - loss: 4.8206 - accuracy: 0.2208
# Epoch 2/3
# 11/11 [==============================] - 0s 1ms/step - loss: 2.8060 - accuracy: 0.3229
# Epoch 3/3
# 11/11 [==============================] - 0s 1ms/step - loss: 2.0705 - accuracy: 0.3989
```
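A variant of the same idea, sketched under the assumption of TensorFlow 2.x and random stand-in data with the question's shapes (5420 samples, 212 features, 5 classes; the values are not real): the per-sample shape is declared with an explicit `tf.keras.Input` layer, and the labels are one-hot encoded with `tf.keras.utils.to_categorical` instead of scikit-learn's `OneHotEncoder`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data with the question's shapes.
rng = np.random.default_rng(0)
X = rng.normal(size=(5420, 212)).astype("float32")
y = rng.integers(0, 5, size=5420)

# One-hot encode integer labels: shape (5420,) -> (5420, 5).
Y = tf.keras.utils.to_categorical(y, num_classes=5)

# Declare the per-sample input shape once, with an explicit Input layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=X.shape[1:]),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(Y.shape[1], activation="softmax"),
])
model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])

print(model.input_shape)   # (None, 212)
print(model.output_shape)  # (None, 5)
history = model.fit(X, Y, epochs=1, batch_size=512, verbose=0)
```

Deriving both shapes from the arrays keeps the model in sync with the data if the number of features or classes changes; recent Keras versions also warn when `input_shape` is passed to a layer and suggest an `Input` object as the first layer instead.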
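Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.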