
Using Keras to build a simple ANN, but getting 0 accuracy

I'm learning Keras, and my task is simple: use 3 variables from the data to predict the value of another variable. ANN models in R and Python handle this with acceptable accuracy, but when I build the neural network with Keras it fails, reporting an accuracy of 0. Here is the code:

from keras.models import Sequential
from keras.layers import Dense, Activation
import pandas as pd
# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split now lives in model_selection
from sklearn.model_selection import train_test_split

data = pd.read_csv(datapath)
data = data.dropna()

# Columns 2, 9 and 10 are the inputs; column 8 is the value to predict
x = data.values[:, [2, 9, 10]]
y = data.values[:, 8]
train_X, test_X, train_y, test_y = train_test_split(x, y, train_size=0.5, random_state=0)

# One hidden layer of 16 sigmoid units, followed by a single linear output unit
model = Sequential()
model.add(Dense(16, input_shape=(3,)))
model.add(Activation('sigmoid'))
model.add(Dense(1))
model.add(Activation('linear'))

model.compile(optimizer='sgd', loss='mean_squared_error', metrics=['accuracy'])

# 'nb_epoch' is the Keras 1 name; Keras 2 renamed the argument to 'epochs'
model.fit(train_X, train_y, epochs=100, batch_size=32, verbose=0)
loss, accuracy = model.evaluate(test_X, test_y, verbose=0)
print("Accuracy = {:.2f}".format(accuracy))

So my question is: is my code right, and how can I make it work?

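If you are trying to predict a real-valued quantity, you are probably not interested in accuracy, which is meant for classification: it only counts predictions that match the target class exactly, and with continuous targets that almost never happens, so it stays at or near 0. I would suggest measuring training progress with the loss you are already using: mean squared error.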
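To see why the reported accuracy is 0, here is a small NumPy sketch. It assumes that Keras resolves metrics=['accuracy'] to a binary-accuracy style check for a single-unit output (thresholding predictions to 0/1 and comparing them for equality with the targets), which is what Keras 2 does; the helper names and the data below are hypothetical and only imitate that behaviour.

```python
import numpy as np

def binary_accuracy_like(y_true, y_pred, threshold=0.5):
    # Imitates a binary-accuracy metric: predictions are turned into 0/1
    # and compared for exact equality with the targets.
    y_pred_binary = (np.asarray(y_pred, dtype=float) > threshold).astype(float)
    return float(np.mean(np.asarray(y_true, dtype=float) == y_pred_binary))

def mse(y_true, y_pred):
    # Mean squared error: the loss the question already optimizes.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))

# Hypothetical real-valued targets and fairly close predictions.
y_true = np.array([12.3, 45.1, 33.7, 27.9])
y_pred = np.array([12.0, 44.8, 34.1, 28.2])

print(binary_accuracy_like(y_true, y_pred))  # 0.0
print(mse(y_true, y_pred))                   # about 0.11
```

The predictions are close to the targets and the MSE is small, yet the "accuracy" is 0, because no real-valued target ever equals a thresholded 0 or 1.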

If you want to monitor training progress on held-out data, you can pass the keyword argument validation_data=(X_val, y_val) to the fit method.
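A minimal sketch of passing validation_data to fit, assuming the Keras 2.x or 3.x API. Synthetic data stands in for the question's CSV columns, and the names rng, x_train, x_val, y_train and y_val are hypothetical. The metric is switched to MSE because this is a regression problem.

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Input

rng = np.random.default_rng(0)
# Hypothetical synthetic data: 3 input features and one real-valued target.
x = rng.normal(size=(200, 3)).astype("float32")
y = (x @ np.array([1.5, -2.0, 0.5], dtype="float32") + 3.0).reshape(-1, 1).astype("float32")

# Simple hold-out split: first 150 rows for training, last 50 for validation.
x_train, x_val = x[:150], x[150:]
y_train, y_val = y[:150], y[150:]

model = Sequential([
    Input(shape=(3,)),
    Dense(16, activation="sigmoid"),
    Dense(1),  # linear output unit for regression
])
# Report MSE instead of accuracy.
model.compile(optimizer="sgd", loss="mean_squared_error", metrics=["mse"])

history = model.fit(
    x_train, y_train,
    epochs=20, batch_size=32, verbose=0,
    validation_data=(x_val, y_val),  # evaluated at the end of every epoch
)
print(history.history["val_loss"][-1])
```

history.history then holds one training loss and one validation loss per epoch, which is usually enough to see whether the network is still improving or has started to overfit.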

Otherwise, to measure the prediction error on the test set, just use scikit-learn's mean_squared_error function: http://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
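A short sketch of that step. The helper holdout_mse is hypothetical, and the arrays below stand in for test_y and the output of model.predict(test_X), which for a single-unit output has shape (n_samples, 1).

```python
import numpy as np
from sklearn.metrics import mean_squared_error

def holdout_mse(y_true, y_pred):
    # Flatten both arrays so a (n, 1) prediction lines up with a 1-D target vector.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return mean_squared_error(y_true, y_pred)

# Hypothetical targets and predictions (e.g. predictions = model.predict(test_X)).
test_y = np.array([3.0, -0.5, 2.0, 7.0])
predictions = np.array([[2.5], [0.0], [2.0], [8.0]])

mse = holdout_mse(test_y, predictions)
rmse = np.sqrt(mse)  # root mean squared error, in the same units as the target
print(mse)  # 0.375
```

Reporting the RMSE alongside the MSE makes the error easier to compare with the numbers you got from the R and Python models.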

Update: one difference from other packages might be how the weights are initialized. Experimenting with the initialization parameter of the Dense layers (kernel_initializer in Keras 2, init in Keras 1) or with the scale of the initialization might help you.
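A minimal sketch of trying different initializers on the question's architecture, assuming the Keras 2.x (TensorFlow 2) or 3.x API. The helper build_model and the candidates dictionary are hypothetical; RandomNormal's stddev argument is the "scale of initialization" knob.

```python
import numpy as np
from keras import initializers
from keras.layers import Dense, Input
from keras.models import Sequential

def build_model(make_init):
    # make_init is called once per layer so each layer gets its own initializer object.
    return Sequential([
        Input(shape=(3,)),
        Dense(16, activation="sigmoid", kernel_initializer=make_init()),
        Dense(1, kernel_initializer=make_init()),
    ])

candidates = {
    "glorot_uniform (Dense default)": lambda: initializers.GlorotUniform(),
    "he_normal": lambda: initializers.HeNormal(),
    "normal, stddev=0.5": lambda: initializers.RandomNormal(mean=0.0, stddev=0.5),
}

for name, make_init in candidates.items():
    model = build_model(make_init)
    hidden_kernel = model.get_weights()[0]  # weights of the first Dense layer, shape (3, 16)
    print(name, hidden_kernel.shape, float(np.std(hidden_kernel)))
```

Each candidate model would then be compiled and trained as in the question, comparing their validation MSE rather than accuracy.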
