Keras functional API: what should the Input layer be for the Embedding layer?

I am using the Keras functional API to build a neural network that takes a word embedding layer as input for a sentence classification task. But my code breaks right at the start, when connecting the input and embedding layers. Following the tutorial at https://medium.com/tensorflow/predicting-the-price-of-wine-with-the-keras-functional-api-and-tensorflow-a95d1c2c1b03 , I have the code below:

from keras.layers import Input, Embedding, Dense
from keras.models import Model

max_seq_length = 100  # i.e., a sentence has at most 100 words
word_weight_matrix = ...  # shape (9825, 300): the vocabulary has 9825 words, each a 300-dimensional vector
deep_inputs = Input(shape=(max_seq_length,))
embedding = Embedding(9825, 300, input_length=max_seq_length,
                      weights=word_weight_matrix, trainable=False)(deep_inputs)  # line A
hidden = Dense(targets, activation="softmax")(embedding)
model = Model(inputs=deep_inputs, outputs=hidden)

Line A then raises the following error:

ValueError: You called `set_weights(weights)` on layer "embedding_1" with a  weight list of length 9825, but the layer was expecting 1 weights. Provided weights: [[-0.04057981  0.05743935  0.0109863  ...,  0.0072...

I don't really understand what the error means.

It seems that the Input layer isn't defined properly. Previously, when I used the Sequential model with the embedding layer defined exactly the same way, everything worked. But after switching to the functional API, I get this error.

Any help is much appreciated. Thanks in advance.

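Try this updated code. There are two changes. First, pass the weights as a list, weights=[word_weight_matrix]: Keras expects one array per layer weight, and the Embedding layer has exactly one, so a bare 9825-row matrix is read as a list of 9825 separate weights. Second, use len(vocabulary) + 1 as the first argument of the Embedding layer if index 0 is reserved for padding. In that case the weight matrix needs the extra row as well, since its row count has to match that first argument.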

from keras.layers import Input, Embedding, Dense
from keras.models import Model

max_seq_length = 100  # i.e., a sentence has at most 100 words
word_weight_matrix = ...  # pretrained vectors; the row count must match the first Embedding argument, so (9826, 300) here if index 0 is reserved for padding
deep_inputs = Input(shape=(max_seq_length,))
embedding = Embedding(9826, 300, input_length=max_seq_length,
                      weights=[word_weight_matrix], trainable=False)(deep_inputs)  # line A
hidden = Dense(targets, activation="softmax")(embedding)
model = Model(inputs=deep_inputs, outputs=hidden)
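
To make the error message less mysterious, here is a rough sketch of the two checks involved: the number of arrays in the weights list, and the shape of the single embedding matrix. It is a simplified pure-Python stand-in, not the Keras source. The helper names check_embedding_weights and describe, and the nested-list matrices used in place of NumPy arrays, are assumptions made only for illustration.

```python
# Simplified stand-in for the checks Keras runs when it receives `weights`.
# Hypothetical helpers for illustration only; this is not the Keras source.

def check_embedding_weights(input_dim, output_dim, weights):
    """Validate `weights` the way a layer with a single variable would."""
    # Check 1: one array per layer variable. An Embedding layer has exactly one.
    if len(weights) != 1:
        raise ValueError(
            "weight list of length %d, but the layer was expecting 1 weights"
            % len(weights)
        )
    # Check 2: that single array must have shape (input_dim, output_dim).
    matrix = weights[0]
    shape = (len(matrix), len(matrix[0]))
    if shape != (input_dim, output_dim):
        raise ValueError(
            "layer weight shape %s not compatible with provided shape %s"
            % ((input_dim, output_dim), shape)
        )
    return shape


def describe(input_dim, output_dim, weights):
    """Return a short status string instead of raising."""
    try:
        return "ok %s" % (check_embedding_weights(input_dim, output_dim, weights),)
    except ValueError as exc:
        return "error: %s" % exc


# Nested lists stand in for the pretrained (9825, 300) matrix.
word_weight_matrix = [[0.0] * 300 for _ in range(9825)]
# Same matrix with an extra zero row at index 0 (assumed padding index).
padded_matrix = [[0.0] * 300] + word_weight_matrix

bare = describe(9825, 300, word_weight_matrix)        # matrix not wrapped in a list
mismatch = describe(9826, 300, [word_weight_matrix])  # wrapped, but one row short
same_rows = describe(9825, 300, [word_weight_matrix])
with_padding = describe(9826, 300, [padded_matrix])

for result in (bare, mismatch, same_rows, with_padding):
    print(result)
```

The first call fails the same way the question does, because the length of a bare matrix is its row count. Wrapping the matrix in a list gets past that check, and the remaining cases show that 9826 only works when the matrix itself has 9826 rows.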
