
Tensorflow/Keras embedding layer applied to a tensor

I would like to create a TensorFlow model that takes a list of integers as input and returns the corresponding pre-trained embeddings.

For example, if the input batch is [[1, 2, 3], [4, 5, 6]], I would like the model to return [[embed[1], embed[2], embed[3]], [embed[4], embed[5], embed[6]]], where embed is a matrix that contains the pre-trained embeddings.

I think I was able to create an embedding layer with pre-trained embeddings, but my code only returns one embedding.

import numpy as np
import tensorflow as tf

embedding_dim = 5
vocab_size = 100

embedding_matrix = np.random.random((vocab_size, embedding_dim))

emb_model = tf.keras.Sequential()
embedder = tf.keras.layers.Embedding(vocab_size, 
                                     embedding_dim,
                                     embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix),
                                     trainable=False,
                                     input_shape=(None,))
emb_model.add(embedder)

For instance, if I call emb_model([[[8, 2, 7], [2, 8, 4]]]), only the embedding for item 8 is returned.

This is the correct way to do it:

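import numpy as np
import tensorflow as tf

embedding_dim = 5
vocab_size = 100

# batch of token ids, shape (n_batch, n_token) = (2, 3)
test = np.asarray([[1, 2, 3], 
                   [4, 5, 6]])

# stand-in for the pre-trained embeddings, one row per id
embedding_matrix = np.random.uniform(-1, 1, (vocab_size, embedding_dim))

emb_model = tf.keras.Sequential()
# load the matrix as the layer weights and freeze it
embedder = tf.keras.layers.Embedding(vocab_size, 
                                     embedding_dim,
                                     trainable=False,
                                     weights=[embedding_matrix],
                                     input_shape=(None,))
emb_model.add(embedder)

# look up the embeddings for every id in the batch
emb_model(test)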

We initialize a weight matrix and load it into the layer with weights=[embedding_matrix], setting trainable=False so that the embeddings stay fixed.

At this point, we can compute the embeddings directly by passing the ids of interest as a NumPy array, as in emb_model(test).

The result is an array of shape (n_batch, n_token, embedding_dim), which is (2, 3, 5) for the test batch above.
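As a rough cross-check of that shape, the same lookup can be sketched in plain NumPy: indexing the matrix with the array of ids picks one row per id. This is only a reference sketch and not part of the Keras model. The seeded generator rng and the name expected are introduced here for illustration and do not appear in the answer.

```python
import numpy as np

embedding_dim = 5
vocab_size = 100

# Stand-in for pre-trained embeddings (seeded so the run is repeatable;
# the seed is an illustrative choice).
rng = np.random.default_rng(0)
embedding_matrix = rng.uniform(-1, 1, (vocab_size, embedding_dim))

# Batch of token ids with shape (n_batch, n_token) = (2, 3).
test = np.asarray([[1, 2, 3],
                   [4, 5, 6]])

# Integer-array indexing selects one row of the matrix per id.
expected = embedding_matrix[test]

print(expected.shape)  # (2, 3, 5)
```

If the Keras layer is built from the same embedding_matrix, its output on test should match expected up to float32 rounding, assuming the layer stores its weights as float32.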

