I'm trying to predict network traffic from past values. I built an LSTM network and tried several parameter settings, but I always end up with the same very low accuracy (0.108).
```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras import optimizers
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential

# `dataset` is the raw traffic series as a 2-D array of shape (n_samples, 1)
scaler = MinMaxScaler(feature_range=(0, 1))
dataset = scaler.fit_transform(dataset)

# Chronological 67/33 train/test split (no shuffling for a time series)
train_size = int(len(dataset) * 0.67)
test_size = len(dataset) - train_size
train, test = dataset[0:train_size, :], dataset[train_size:len(dataset), :]
print(len(train), len(test))

def create_dataset(dataset, window_size=1):
    """Build (X, Y) pairs: X holds `window_size` consecutive values, Y the value after them."""
    data_X, data_Y = [], []
    for i in range(len(dataset) - window_size):
        a = dataset[i:(i + window_size), 0]
        data_X.append(a)
        data_Y.append(dataset[i + window_size, 0])
    return np.array(data_X), np.array(data_Y)

window_size = 1
train_X, train_Y = create_dataset(train, window_size)
test_X, test_Y = create_dataset(test, window_size)
print("Original training data shape:")
print(train_X.shape)

# Reshape the input data into the form Keras expects: (samples, timesteps=1, features=window_size)
train_X = np.reshape(train_X, (train_X.shape[0], 1, train_X.shape[1]))
test_X = np.reshape(test_X, (test_X.shape[0], 1, test_X.shape[1]))

model = Sequential()
model.add(LSTM(4, input_shape=(1, window_size)))
model.add(Dense(1))
opt = optimizers.SGD(learning_rate=0.01, momentum=0.9)
model.compile(loss="mean_squared_error", optimizer=opt, metrics=['accuracy'])
```
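To see what `create_dataset` and the reshape actually feed the LSTM, here is a NumPy-only toy run on a hypothetical six-value series (no Keras involved). This copy loops over `range(len(dataset) - window_size)` so that the last value is also used as a target.

```python
import numpy as np

def create_dataset(dataset, window_size=1):
    # Each X row holds `window_size` consecutive values; Y is the value right after them
    data_X, data_Y = [], []
    for i in range(len(dataset) - window_size):
        data_X.append(dataset[i:i + window_size, 0])
        data_Y.append(dataset[i + window_size, 0])
    return np.array(data_X), np.array(data_Y)

# Hypothetical toy series 0, 1, ..., 5, shaped (n, 1) like the scaled dataset
series = np.arange(6, dtype=float).reshape(-1, 1)

for window_size in (1, 3):
    X, Y = create_dataset(series, window_size)
    # Same reshape as above: (samples, timesteps=1, features=window_size)
    X_lstm = X.reshape(X.shape[0], 1, X.shape[1])
    print(window_size, X_lstm.shape, X[0], Y[0])
```

With `window_size = 1`, each sample is a single past value, and because timesteps is fixed at 1 the LSTM never recurs over the window; reshaping to `(samples, window_size, 1)` is the more usual layout when the window should be treated as a sequence.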
As you can see, my loss starts at quite a low value, and my accuracy stays constant over time. What am I doing wrong?
Thanks in advance. :)
You can find the loss and accuracy graphs here: loss, accuracy
If (13942, 1, 1) is the shape of your entire dataset, it's far too small for deep learning; you're better off using 'shallow' methods, e.g. Support Vector Machines (SVM). Alternatively, consider my answer here.
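The answer recommends SVMs. As an even simpler 'shallow' illustration that needs only NumPy, here is a linear autoregressive model fit by least squares on lagged windows; this is a swapped-in technique, not an SVM. The synthetic sine-plus-noise series, the 8-value window, and the helper `make_windows` are hypothetical stand-ins for the real traffic data. A persistence forecast (predict the last observed value) is included as a yardstick for what a 'low' MSE means on this data.

```python
import numpy as np

def make_windows(series, window_size):
    """Hypothetical helper: rows of `window_size` past values, next value as the target."""
    X = np.array([series[i:i + window_size] for i in range(len(series) - window_size)])
    y = series[window_size:]
    return X, y

# Synthetic stand-in for scaled traffic: a periodic pattern plus noise, roughly in [0, 1]
rng = np.random.default_rng(42)
t = np.arange(2000)
series = 0.5 + 0.4 * np.sin(2 * np.pi * t / 96) + rng.normal(0.0, 0.02, t.size)

window_size = 8
X, y = make_windows(series, window_size)
split = int(len(X) * 0.67)  # chronological split, as in the question
X_train, X_test, y_train, y_test = X[:split], X[split:], y[:split], y[split:]

# Linear AR model y ≈ X @ w + b, solved in closed form by least squares
A_train = np.column_stack([X_train, np.ones(len(X_train))])
w, *_ = np.linalg.lstsq(A_train, y_train, rcond=None)
A_test = np.column_stack([X_test, np.ones(len(X_test))])
pred = A_test @ w

mse_ar = np.mean((pred - y_test) ** 2)
# Persistence baseline: the forecast is simply the last value in each window
mse_persistence = np.mean((X_test[:, -1] - y_test) ** 2)
print(f"linear AR MSE: {mse_ar:.5f}  persistence MSE: {mse_persistence:.5f}")
```

On real data, fitting scikit-learn's `sklearn.svm.SVR` on the same `X_train`/`y_train` windows would follow the answer's suggestion more closely.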
EDIT: I just noticed that you use the `accuracy` metric; accuracy is undefined for regression, and I'm surprised an error wasn't thrown. If it uses `prediction == true` to compute accuracy, you're lucky your accuracy isn't 0. Based on the loss, your model actually appears to be doing quite well; to double-check, plot `predictions` vs `true` and compare. (In general, `mse < .3` is good, and `mse < .1` is excellent.)
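A minimal NumPy sketch of why an accuracy score says nothing about a regressor, using hypothetical synthetic data: predictions within a few hundredths of continuous targets still score zero when accuracy is computed by equality. The `binary_acc` line thresholds predictions at 0.5 before comparing, which is roughly what Keras does when it resolves the string 'accuracy' for a single-unit output; treat that as an assumption about the Keras version in use.

```python
import numpy as np

# Hypothetical data: continuous targets in [0, 1), like MinMax-scaled traffic
rng = np.random.default_rng(0)
y_true = rng.uniform(0.0, 1.0, size=1000)
# A "good" regressor: predictions off by only a few hundredths on average
y_pred = y_true + rng.normal(0.0, 0.05, size=1000)

mse = np.mean((y_pred - y_true) ** 2)

# Accuracy as exact equality: continuous predictions essentially never match
exact_acc = np.mean(y_pred == y_true)

# Binary-style accuracy: threshold predictions at 0.5, then compare to the targets;
# targets strictly between 0 and 1 can never equal a 0/1 prediction
binary_acc = np.mean((y_pred > 0.5).astype(float) == y_true)

print(f"MSE: {mse:.4f}  exact accuracy: {exact_acc}  binary accuracy: {binary_acc}")
```

The MSE is small while both accuracy figures are 0, so the loss (or a plot of predictions against targets) is the number to watch.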