简体   繁体   English

Keras - 如何为每个Input-Neuron构建共享的Embedding()层

[英]Keras - How to construct a shared Embedding() Layer for each Input-Neuron

I want to create a deep neural network in keras, where each element of the input layer is "encoded" using the same, shared Embedding()-layer, before it is fed into the deeper layers. 我想在keras中创建一个深度神经网络,其中输入层的每个元素在被馈送到更深层之前使用相同的共享嵌入()层进行“编码”。

Each input would be a number that defines the type of an object, and the network should learn an embedding that encapsulates some internal representation of "what this object is". 每个输入都是一个定义对象类型的数字,网络应该学习一个嵌入,它封装了“这个对象是什么”的内部表示。

So, if the input layer has X dimensions, and the embedding has Y dimensions, the first hidden layer should consist of X*Y neurons (each input neuron embedded). 因此,如果输入层具有X维度,并且嵌入具有Y维度,则第一隐藏层应该由X * Y神经元(每个嵌入的输入神经元)组成。

Here is a little image that should show the network architecture that I would like to create, where each input-element is encoded using a 3D-Embedding 这是一个小图像,应该显示我想要创建的网络架构,其中每个输入元素使用3D嵌入进行编码

How can I do this? 我怎样才能做到这一点?

from keras.layers import Input, Embedding

first_input = Input(shape = (your_shape_tuple) )
second_input = Input(shape = (your_shape_tuple) )
...

embedding_layer = Embedding(embedding_size)

first_input_encoded = embedding_layer(first_input)
second_input_encoded = embedding_layer(second_input)
...

Rest of the model....

The emnedding_layer will have shared weights. emnedding_layer将具有共享权重。 You can do this in form of lists of layers if you have a lot of inputs. 如果您有大量输入,则可以以图层列表的形式执行此操作。

If what you want is transforming a tensor of inputs, the way to do it is : 如果您想要的是改变输入的张量,那么这样做的方法是:

from keras.layers import Input, Embedding

# If your inputs are all fed in one numpy array :
input_layer = Input(shape = (num_input_indices,) )

# the output of this layer will be a 2D tensor of shape (num_input_indices, embedding_size)
embedded_input = Embedding(embedding_size)(input_layer)

Is this what you were looking for? 这是你在找什么?

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM