
Keras simple feed-forward network input shape error

I am trying to train a very simple feed-forward network in Keras. I want to give the network 1800 numbers and have it activate 1 of 6 outputs.

My model is set up as follows:

import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(256, input_dim=1800, activation=tf.nn.relu),
    keras.layers.Dense(48, activation=tf.nn.relu),
    keras.layers.Dense(6, activation=tf.nn.softmax)
])

My data is set up as follows:

It is split into two Python lists, training_data and training_labels.

An element from training_labels is a Python list containing 6 numbers like this:

[0, 0, 0, 0, 1, 0]

An element from training_data is a Python list containing 1800 numbers like this:

[15, 155, 1200, 1, ... ]

There are 1500 examples in total.
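A minimal sketch of what the data described above should look like once converted to NumPy arrays. The random values are stand-ins (an assumption, not the asker's real numbers); what matters is the resulting shapes, which are what Keras compares against input_dim=1800 and the 6-unit output layer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in data with the structure described in the question:
# 1500 Python lists of 1800 numbers, and 1500 one-hot lists of length 6.
training_data = [rng.integers(0, 2000, size=1800).tolist() for _ in range(1500)]
class_ids = rng.integers(0, 6, size=1500)
training_labels = [[1 if j == c else 0 for j in range(6)] for c in class_ids]

# Converting the nested lists gives 2-D arrays with one row per example.
x = np.asarray(training_data, dtype=np.float32)
y = np.asarray(training_labels, dtype=np.float32)

print(x.shape)  # (1500, 1800): matches input_dim=1800
print(y.shape)  # (1500, 6): matches the 6-unit softmax layer
```

If the real data does not produce these two shapes, the model cannot accept it as written.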

To fit the model, I am doing:

model.fit(training_data, training_labels, batch_size=1)

But I get the error:

ValueError: Error when checking input: expected dense_1_input to have shape (None, 1800) but got array with shape (150, 1)
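The error says the array Keras built from the input does not have 1800 columns. One thing worth ruling out (an assumption about the cause, not something the question confirms) is that some examples are not exactly 1800 numbers long. A small hypothetical helper, find_bad_rows, can report such rows before the data ever reaches Keras:

```python
from collections import Counter


def find_bad_rows(rows, expected_len=1800):
    """Return (index, length) for every row whose length differs from expected_len.

    find_bad_rows is a hypothetical helper for illustration, not a Keras API.
    """
    return [(i, len(row)) for i, row in enumerate(rows) if len(row) != expected_len]


# Example: three rows, one of which is one element short.
rows = [[0.0] * 1800, [0.0] * 1799, [0.0] * 1800]
print(find_bad_rows(rows))            # [(1, 1799)]
print(Counter(len(r) for r in rows))  # Counter({1800: 2, 1799: 1})
```

An empty result means every example has the expected length, so the problem lies elsewhere (for example, in how the lists are passed to fit).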

As mentioned in the comments, you probably misunderstand the shape of your data. To prove that, check out the code snippet below.

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Dummy data with the intended shapes: 1500 examples x 1800 features,
# and 1500 x 6 labels (the all-ones labels only serve to check the shapes)
training_data = np.random.rand(1500, 1800)
training_labels = np.ones((1500, 6))
model = keras.Sequential([
    keras.layers.Dense(256, input_dim=1800, activation=tf.nn.relu),
    keras.layers.Dense(48, activation=tf.nn.relu),
    keras.layers.Dense(6, activation=tf.nn.softmax)
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(training_data, training_labels, batch_size=1)

With arrays of these shapes, the same model compiles and trains without the error.

In addition to what has been mentioned, I suggest adding a shape check before feeding the data into your network:

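import numpy as np

training_data = np.asarray(training_data)
training_labels = np.asarray(training_labels)
# Fail early if the arrays do not have the shapes the model expects
assert training_data.shape == (1500, 1800)
assert training_labels.shape == (1500, 6)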
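Extending the same idea to the labels, here is a minimal sketch of a pre-flight check, assuming the labels are one-hot lists of length 6 as described in the question. check_inputs is a hypothetical name, and the expected sizes come from the question rather than from any Keras API:

```python
import numpy as np


def check_inputs(training_data, training_labels, n_features=1800, n_classes=6):
    """Convert the lists to arrays and verify the shapes Keras will see.

    Hypothetical helper; the defaults come from the question (1800 inputs, 6 classes).
    """
    x = np.asarray(training_data, dtype=np.float32)
    y = np.asarray(training_labels, dtype=np.float32)
    assert x.ndim == 2 and x.shape[1] == n_features, f"bad data shape {x.shape}"
    assert y.shape == (x.shape[0], n_classes), f"bad label shape {y.shape}"
    # Each one-hot row should contain exactly one 1.
    assert np.all(y.sum(axis=1) == 1), "labels are not one-hot"
    return x, y


x, y = check_inputs([[1] * 1800, [2] * 1800], [[0, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0]])
print(x.shape, y.shape)  # (2, 1800) (2, 6)
```

The returned arrays can then be passed straight to model.fit(x, y, batch_size=1).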
