
Self-attention using a transformer block in Keras

I'm trying to understand the newly implemented Keras transformer example: https://keras.io/examples/nlp/text_classification_with_transformer/

I see the text is first embedded and then self-attention is applied. But what if I want to use an embedding other than TokenAndPositionEmbedding? E.g., in my case I have pre-embedded sentences and would like to apply self-attention to them.

What I don't understand is what self.pos_emb does. The class TokenAndPositionEmbedding is returning x and positions, with x being the token embedding and positions being the number of words to consider? So it's basically returning two things? I don't understand that.

import tensorflow as tf
from tensorflow.keras import layers

class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super(TokenAndPositionEmbedding, self).__init__()
        # One embedding table for token ids, one for positions 0..maxlen-1
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)  # (maxlen, embed_dim)
        x = self.token_emb(x)                # (batch, maxlen, embed_dim)
        # The two embeddings are summed, so a single tensor comes back
        return x + positions
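
For what it's worth, running it with hypothetical sizes gives back a single tensor (the sum of the two embeddings, broadcast over the batch):

maxlen, vocab_size, embed_dim = 20, 1000, 32  # made-up sizes for a quick check
emb = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
tokens = tf.random.uniform((2, maxlen), maxval=vocab_size, dtype=tf.int32)
print(emb(tokens).shape)  # (2, 20, 32) -- one tensor, not two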

Or do I just feed my pre-embedded sentences to MultiHeadSelfAttention and put a Dense layer after it for classification purposes?

As you know, the transformer is a structure built essentially out of Dense layers with residual connections; by itself, this makes time-series data lose its time dependence. So for the transformer you need to encode the position, which you can think of as additional information given to the structure so that it does not lose the time dependence. If you would like to understand it better using Keras, I suggest the official tutorial written by TensorFlow: https://www.tensorflow.org/tutorials/text/transformer, which details the things you would like to know. A minimal sketch for your pre-embedded case follows below.
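
To make that concrete: here is a minimal sketch (my own, not taken from the Keras example; seq_len, embed_dim, num_heads, and num_classes are made-up values, and it assumes TF >= 2.4 for the built-in layers.MultiHeadAttention). It adds a learned position embedding to already-embedded inputs, applies self-attention, then pools and classifies with a Dense layer:

import tensorflow as tf
from tensorflow.keras import layers

class PositionEmbedding(layers.Layer):
    """Adds a learned position embedding to pre-embedded inputs."""
    def __init__(self, maxlen, embed_dim):
        super(PositionEmbedding, self).__init__()
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[1]  # sequence-length axis of (batch, seq, embed)
        positions = tf.range(start=0, limit=maxlen, delta=1)
        return x + self.pos_emb(positions)

seq_len, embed_dim, num_heads, num_classes = 50, 128, 4, 2  # assumptions

inputs = layers.Input(shape=(seq_len, embed_dim))  # your pre-embedded sentences
x = PositionEmbedding(seq_len, embed_dim)(inputs)  # inject order information
# Self-attention: query = value = x
x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)
x = layers.GlobalAveragePooling1D()(x)             # pool over the sequence
outputs = layers.Dense(num_classes, activation="softmax")(x)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.summary()

GlobalAveragePooling1D is just one choice here; any pooling over the sequence axis (or taking a particular position's output) works before the classifier.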
