
Feeding data into an LSTM using tflearn (Python)

I know there are already some questions in this area, but I couldn't find an answer to my problem. I have an LSTM (built with tflearn) for a regression problem. I get one of three errors, no matter what modifications I make.

import pandas
import tflearn
import tensorflow as tf
from sklearn.cross_validation import train_test_split


csv = pandas.read_csv('something.csv', sep = ',')

X_train, X_test = train_test_split(csv.loc[:,['x1', 'x2',
                                'x3','x4','x5','x6',
                                'x7','x8','x9',
                                'x10']].as_matrix())
Y_train, Y_test = train_test_split(csv.loc[:,['y']].as_matrix())

#create LSTM
g = tflearn.input_data(shape=[None, 1, 10])
g = tflearn.lstm(g, 512, return_seq = True)
g = tflearn.dropout(g, 0.5)
g = tflearn.lstm(g, 512)
g = tflearn.dropout(g, 0.5)
g = tflearn.fully_connected(g, 1, activation='softmax')
g = tflearn.regression(g, optimizer='adam', loss = 'mean_square',
                   learning_rate=0.001)

model = tflearn.DNN(g)
model.fit(X_train, Y_train, validation_set = (Y_train, Y_test))

n_examples = Y_train.size
def mean_squared_error(y,y_):
    return tf.reduce_sum(tf.pow(y_ - y, 2))/(2 * n_examples)

print()
print("\nTest prediction")
print(model.predict(X_test))
print(Y_test)
Y_pred = model.predict(X_test)
print('MSE Test: %.3f' % ( mean_squared_error(Y_test,Y_pred)) )

On the first run after starting a new kernel, I get:

ValueError: Cannot feed value of shape (100, 10) for Tensor 'InputData/X:0', which has shape '(?, 1, 10)'

Then, on the second run, I get:

AssertionError: Input dim should be at least 3.

which refers to the second LSTM layer. I tried removing the second LSTM and Dropout layers, but then I get:

 feed_dict[net_inputs[i]] = x
IndexError: list index out of range

If you read this, have a nice day. If you answer it, thanks a lot!

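OK, I solved it. I'm posting the fix in case it helps somebody. The input layer is declared with shape [None, 1, 10], so the 2-D feature arrays have to be reshaped to 3-D (samples, timesteps, features) before they are fed to the network: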

X_train = X_train.reshape([-1,1,10])
X_test = X_test.reshape([-1,1,10])
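
To illustrate what the reshape does, here is a minimal NumPy-only sketch. It assumes the features arrive as a 2-D array of shape (100, 10), as in the ValueError above. The synthetic `features` array is a hypothetical stand-in for the columns x1..x10 of something.csv, and no tflearn model is built here.

```python
import numpy as np

# Hypothetical stand-in for the x1..x10 columns read from the CSV.
rng = np.random.default_rng(0)
features = rng.normal(size=(100, 10))  # 2-D: (samples, features)

# input_data(shape=[None, 1, 10]) expects (samples, timesteps, features),
# so insert a timestep axis of length 1. The -1 lets NumPy infer the
# number of samples.
X = features.reshape([-1, 1, 10])

print(features.shape)  # (100, 10)
print(X.shape)         # (100, 1, 10)

# The reshape only adds an axis; each sample keeps its 10 feature values.
assert np.array_equal(X[:, 0, :], features)
```

Each row is treated as a sequence of length 1, which matches the shape the input layer declares. Whether a single timestep is a meaningful use of an LSTM depends on the data.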


 