
Keras simplest conv network not learning anything

I wrote simple code to learn Keras:

from tensorflow import keras


def main():
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

    model = keras.Sequential()

    model.add(keras.layers.Conv2D(16, 3, padding='same', activation='relu'))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(10, activation='softmax'))

    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    model.fit(x_train, y_train, epochs=4)

    model.summary()


if __name__ == '__main__':
    main()

But it does not seem to learn anything. I don't expect it to learn much, but the loss should at least decrease and the accuracy increase a little. Instead, both stay the same every epoch.

I had the exact same model written in PyTorch and it achieved around 35% accuracy. This version in TensorFlow + Keras is stuck at 10%.

tensorflow-gpu v1.9

What am I missing?

I think the default learning rate is too high for this problem. Try something like

opt=keras.optimizers.Adam(lr=1.e-5)
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', 
              metrics=['accuracy'])

I checked the default learning rate used by Adam in both Keras and PyTorch, and they both use 1e-3. Therefore, the learning rate should not be the issue, assuming you use the defaults in both models.

Alternatively, I think this is related to the weight initialization, which is explicitly handled by each layer in Keras but not in PyTorch.

Simply changing the training line to the following, which scales the pixel values from [0, 255] to [0, 1] and adds validation data,

model.fit(x_train/255., y_train, shuffle=True, validation_data=(x_test/255., y_test), epochs=4)

you should observe both training and validation accuracy reach around 60%.
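The division by 255 in that line is doing the real work: `cifar10.load_data()` returns `uint8` images with values in [0, 255], and inputs that large plausibly saturate the softmax at the start of training. Below is a minimal sketch of the scaling step on its own, using NumPy only. The `scale_pixels` helper and the random `fake_images` array are hypothetical stand-ins for illustration, not part of the original code.

```python
import numpy as np


def scale_pixels(images):
    """Convert uint8 images in [0, 255] to float32 in [0, 1]."""
    return images.astype("float32") / 255.0


# Hypothetical stand-in for x_train: 8 random 32x32 RGB images stored as
# uint8, the same dtype and axis layout that cifar10.load_data() returns.
rng = np.random.RandomState(0)
fake_images = rng.randint(0, 256, size=(8, 32, 32, 3)).astype("uint8")

scaled = scale_pixels(fake_images)

# Raw range versus scaled range.
print(fake_images.dtype, fake_images.min(), fake_images.max())
print(scaled.dtype, scaled.min(), scaled.max())
```

The same helper could be applied once to both `x_train` and `x_test` before calling `model.fit`, so the training and validation inputs are guaranteed to share the same scaling.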

I am not familiar with PyTorch, but I suggest you initialize the weights in the Keras network with those used by the PyTorch network. That way, you will have a fair comparison.
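To get a feel for how much the default initializations might differ, here is a rough sketch that compares the sampling ranges for the two layers in the question. It assumes that Keras defaults to `glorot_uniform` for kernels and that PyTorch's default for `Conv2d`/`Linear` is `kaiming_uniform_` with `a=sqrt(5)`, which works out to a uniform range of ±1/sqrt(fan_in). Both assumptions hold for recent versions but may not for every release, and the helper names are made up for this example.

```python
import math


def glorot_uniform_limit(fan_in, fan_out):
    # Assumed Keras default (glorot_uniform): samples from U(-limit, limit)
    # with limit = sqrt(6 / (fan_in + fan_out)).
    return math.sqrt(6.0 / (fan_in + fan_out))


def pytorch_default_limit(fan_in):
    # Assumed PyTorch default (kaiming_uniform_ with a=sqrt(5)), which
    # reduces to U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
    return 1.0 / math.sqrt(fan_in)


# Fan sizes for the model in the question on 32x32x3 CIFAR-10 inputs.
conv_fan_in, conv_fan_out = 3 * 3 * 3, 3 * 3 * 16   # 3x3 kernel, 3 -> 16 channels
dense_fan_in, dense_fan_out = 32 * 32 * 16, 10      # flattened feature map -> 10 classes

for name, fan_in, fan_out in [
    ("conv", conv_fan_in, conv_fan_out),
    ("dense", dense_fan_in, dense_fan_out),
]:
    keras_limit = glorot_uniform_limit(fan_in, fan_out)
    torch_limit = pytorch_default_limit(fan_in)
    print(f"{name}: keras ±{keras_limit:.5f}, pytorch ±{torch_limit:.5f}")
```

Under these assumptions the convolution ranges are nearly identical, while the dense layer's range is a bit more than twice as wide in Keras. If you want Keras to sample from the PyTorch-style range, one option is to pass `kernel_initializer=keras.initializers.VarianceScaling(scale=1/3, mode='fan_in', distribution='uniform')` to each layer, since that initializer's uniform limit is sqrt(3 * scale / fan_in).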


 