Tensorflow Model Input Shape Error: Input 0 of layer sequential_11 incompatible with layer: rank undefined, but the layer requires a defined rank
Setting the shape of tensorflow sequential model input layer
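I'm trying to build a model for multi-class classification, but I don't understand how to set the correct input shape. I have a training set of shape (5420, 212), and this is the model I built: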
from tensorflow.keras import models, layers
model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape = (5420,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(5, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
history = model.fit(X_train, y_train, epochs=20, batch_size=512)
When I run it, I get this error:
ValueError: Input 0 of layer sequential_9 is incompatible with the layer: expected axis -1 of input shape to have value 5420 but received input with shape (None, 212)
Why? Is the input shape I passed incorrect?
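The input shape should equal the second dimension of the input X, and the output shape should equal the second dimension of the output Y (assuming X and Y are both two-dimensional, i.e. they have no further dimensions). Keras's input_shape describes a single sample and leaves out the batch axis, so for a training set of shape (5420, 212) it must be (212,): 5420 is the number of samples, not the number of features.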
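To see why the last axis is the one that matters, here is a minimal NumPy sketch of what a Dense(64) layer computes. It is only an illustration, not Keras's actual implementation: the random data and the names W, b and W_wrong are assumptions made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the training set: 5420 samples, 212 features each.
X = rng.normal(size=(5420, 212))

# A dense layer with 64 units holds a weight matrix of shape (n_features, 64).
# Its row count has to match the LAST axis of the input.
n_features = X.shape[1]
W = rng.normal(size=(n_features, 64))
b = np.zeros(64)

hidden = np.maximum(X @ W + b, 0.0)  # ReLU(XW + b)
print(hidden.shape)

# Declaring 5420 input features would instead give a (5420, 64) weight
# matrix, which cannot be multiplied with rows of 212 features.
W_wrong = rng.normal(size=(5420, 64))
try:
    X @ W_wrong
except ValueError as err:
    print("shape mismatch:", err)
```

The number of rows in X (the sample count) never appears in the weight shapes, which is why it does not belong in input_shape.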
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.datasets import make_classification
from sklearn.preprocessing import OneHotEncoder
tf.random.set_seed(0)
# generate some data
X, y = make_classification(n_classes=5, n_samples=5420, n_features=212, n_informative=212, n_redundant=0, random_state=42)
print(X.shape, y.shape)
# (5420, 212) (5420,)
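# one-hot encode the target (.toarray() densifies the default sparse output;
# the sparse=False keyword was renamed to sparse_output in newer scikit-learn)
Y = OneHotEncoder().fit_transform(y.reshape(-1, 1)).toarray()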
print(X.shape, Y.shape)
# (5420, 212) (5420, 5)
# extract the input and output shapes
input_shape = X.shape[1:] # the input shape is X's second dimension
output_shape = Y.shape[1] # the output shape is Y's second dimension
print(input_shape, output_shape)
# (212,) 5
# define the model
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=input_shape))
model.add(Dense(64, activation='relu'))
model.add(Dense(output_shape, activation='softmax'))
# compile the model
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
# fit the model
history = model.fit(X, Y, epochs=3, batch_size=512)
# Epoch 1/3
# 11/11 [==============================] - 0s 1ms/step - loss: 4.8206 - accuracy: 0.2208
# Epoch 2/3
# 11/11 [==============================] - 0s 1ms/step - loss: 2.8060 - accuracy: 0.3229
# Epoch 3/3
# 11/11 [==============================] - 0s 1ms/step - loss: 2.0705 - accuracy: 0.3989
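As a side note, recent Keras versions recommend declaring the input with an explicit Input object instead of passing input_shape to the first layer. The sketch below assumes tf.keras is available and uses a small random array, X_demo, as a hypothetical stand-in for the real data. The per-sample shape is still (212,).

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 8 samples with 212 features each.
X_demo = np.random.default_rng(0).normal(size=(8, 212)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(212,)),  # per-sample shape, batch axis omitted
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(5, activation="softmax"),
])

# The model is untrained; this only checks that the shapes line up.
probs = model.predict(X_demo, verbose=0)
print(model.input_shape, probs.shape)
```

The leading None in the reported input shape is the batch axis, which is why any number of rows (8 here, 5420 in the question) is accepted.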