
Keras - LSTM using TensorFlow Backend

I am trying the sample code from here. Earlier I had problems using Keras with Theano on Windows, so I have now migrated to Keras with TensorFlow on Ubuntu. The versions are Keras 2.0.2 and TensorFlow 0.12.1.

keras.__version__        # 2.0.2
tensorflow.__version__   # 0.12.1

import numpy
from keras.datasets import imdb
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
# fix random seed for reproducibility
numpy.random.seed(7)
# load the dataset but only keep the top n words; rarer words are replaced by an out-of-vocabulary index
top_words = 5000

(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=top_words)  # `nb_words` was renamed `num_words` in Keras 2
# truncate and pad input sequences

max_review_length = 500

X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
X_test = sequence.pad_sequences(X_test, maxlen=max_review_length)
# create the model

embedding_vector_length = 32
model = Sequential()
model.add(Embedding(top_words, embedding_vector_length, input_length=max_review_length))
model.add(LSTM(100))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()  # summary() prints the table itself and returns None

model.fit(X_train, y_train, epochs=3, batch_size=64)  # `nb_epoch` was renamed `epochs` in Keras 2
# Final evaluation of the model

scores = model.evaluate(X_test, y_test, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))
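For reference, here is a minimal plain-NumPy sketch of what the truncate-and-pad step does, assuming the default pad_sequences settings in Keras 2 (padding='pre', truncating='pre', value=0): each review keeps only its last max_review_length word indices, and shorter reviews are left-padded with zeros. The name pad_sequences_sketch is hypothetical and only for illustration; the script itself should keep using keras.preprocessing.sequence.pad_sequences.

```python
import numpy as np

def pad_sequences_sketch(sequences, maxlen, value=0):
    """Hypothetical stand-in for Keras pad_sequences with its default settings.
    Assumes maxlen > 0."""
    out = np.full((len(sequences), maxlen), value, dtype="int32")
    for i, seq in enumerate(sequences):
        # 'pre' truncation: drop items from the front, keep the last maxlen
        kept = list(seq)[-maxlen:]
        if kept:
            # 'pre' padding: the filler value stays on the left
            out[i, -len(kept):] = kept
    return out

print(pad_sequences_sketch([[1, 2, 3], [4, 5, 6, 7, 8]], maxlen=4))
```

With maxlen=4, the short sequence [1, 2, 3] becomes [0, 1, 2, 3] and the long one [4, 5, 6, 7, 8] becomes [5, 6, 7, 8].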

The error is as follows; any help solving it would be highly appreciated:

  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_util.py", line 302, in _AssertCompatible
    (dtype.name, repr(mismatch), type(mismatch).__name__))

**TypeError: Expected int32, got list containing Tensors of type '_Message' instead.**

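Your code runs fine on newer versions of TensorFlow (I checked it with TensorFlow 1.0 and Keras 2.0.2); the problem is that you are using an old version of TensorFlow. The error most likely comes from an API change in TensorFlow 1.0: tf.concat swapped its argument order from concat(concat_dim, values) to concat(values, axis), and Keras 2 uses the new order, so TensorFlow 0.12 receives a list of tensors where it expects an int32 dimension. Update TensorFlow and it should work. At the time of writing, the current TensorFlow release is 1.1.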
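A minimal sketch of a guard that fails early on an old TensorFlow instead of deep inside model construction. The helper names and the 1.0 minimum are assumptions for illustration (1.0 is simply the oldest version reported above as working); the check only compares the major and minor numbers of the version string.

```python
import re

def parse_major_minor(version):
    """Return (major, minor) as ints from a string such as '0.12.1' or '1.1.0-rc2'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return int(match.group(1)), int(match.group(2))

def tensorflow_is_new_enough(version, minimum="1.0"):
    """Hypothetical helper: True if `version` is at least `minimum` (assumed 1.0)."""
    # Tuple comparison orders (0, 12) < (1, 0) < (1, 1); plain string
    # comparison would wrongly put '1.10' before '1.9'.
    return parse_major_minor(version) >= parse_major_minor(minimum)

for v in ["0.12.1", "1.0.0", "1.1.0"]:
    print(v, tensorflow_is_new_enough(v))
```

With the versions from the question, '0.12.1' gives False while '1.0.0' and '1.1.0' give True, so calling tensorflow_is_new_enough(tensorflow.__version__) before building the model would flag the outdated install up front.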
