How to vectorize attention operation and avoid for-loop

I am new to attention and have, perhaps naively, implemented an attention mechanism using a Python for-loop in the forward() function of my model, shown below.

Basically, I have an embedding layer for items. From it I get the embedding of one item and the embeddings of a sequence of other items, and I sum the latter weighted by the attention weights. To get the attention weights I use a sub-network (nn.Sequential(...)) that takes a pair of item embeddings as input and outputs a single score, as in regression. All scores are then passed through a softmax and used as attention weights.
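The softmax and weighted-sum steps described above are elided in the code below. Here is a minimal sketch of them, assuming the raw pairwise scores are already computed (random stand-ins here, with arbitrary toy sizes). It shows that the weighted sum over items is a single matrix product and needs no loop:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, I, E = 4, 5, 8  # batch size, number of attended items, embedding size (toy values)

other_embeddings = torch.randn(I, E)   # (I, E): the items being attended over
raw_scores = torch.randn(B, I)         # (B, I): stand-in for the AttentionNet outputs

# Normalize scores over the item dimension so each row sums to 1
attention_weights = F.softmax(raw_scores, dim=1)   # (B, I)

# Weighted sum of the I item embeddings for every batch row:
# (B, I) @ (I, E) -> (B, E), i.e. context[b] = sum_i w[b, i] * other_embeddings[i]
context = attention_weights @ other_embeddings     # (B, E)
```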

def forward(self, input_features, ...):
    ...
    # B = batch size, I = number of items for attention, E = embedding size
    ...
    
    # get embeddings from input features for current batch
    embeddings = self.embedding_layer(input_features)         # (B, E)
    other_embeddings = self.embedding_layer(other_features)   # (I, E)

    # attention between pairs of embeddings
    attention_scores = torch.zeros((B, I), device=embeddings.device)   # (B, I), on the same device as the embeddings
    for i in range(I):
        # repeat batch-size times for i-th embedding
        repeated_other_embedding = other_embeddings[i].view(1, -1).repeat(B, 1)   # (B, E)

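        # concat pairs of embeddings to form input to attention network
        # (note: .detach() stops gradients of the attention scores from reaching the embedding layer)
        item_emb_pairs = torch.cat((embeddings.detach(), repeated_other_embedding.detach()), dim=1)   # (B, 2*E)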

        # pass batch through attention network
        attention_scores[:, [i]] = self.AttentionNet(item_emb_pairs)

    # pass through softmax
    attention_scores = F.softmax(attention_scores, dim=1)   # (B, I)

    ...

How do I avoid the Python for-loop, which I suspect is what slows training down so much? Can I somehow pass a tensor of shape (I, B, 2*E) to self.AttentionNet()?
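On the second question: nn.Linear operates on the last dimension and accepts any number of leading dimensions, so a network built only from Linear layers and elementwise activations can take a 3-D tensor directly. A minimal sketch, assuming a hypothetical attention_net as a stand-in for self.AttentionNet (the hidden size 16 is arbitrary), checks that one batched call matches feeding each (B, 2*E) slice separately, as the loop does:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, I, E = 3, 4, 6  # toy sizes

# Hypothetical stand-in for self.AttentionNet: only Linear layers and elementwise activations
attention_net = nn.Sequential(nn.Linear(2 * E, 16), nn.ReLU(), nn.Linear(16, 1))

pairs = torch.randn(I, B, 2 * E)   # (I, B, 2*E)
scores = attention_net(pairs)      # Linear acts on the last dim -> (I, B, 1)

# Same result as calling the network once per item, like the for-loop does
per_slice = torch.stack([attention_net(pairs[i]) for i in range(I)])   # (I, B, 1)
```

To get the (B, I) layout used in the question, take scores.squeeze(-1).T.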

You can use the following snippet.

embeddings = self.embedding_layer(input_features)         # (B, E)
other_embeddings = self.embedding_layer(other_features)   # (I, E)

embs = embeddings.unsqueeze(1).repeat(1, I, 1)              # (B, I, E)
other_embs = other_embeddings.unsqueeze(0).repeat(B, 1, 1)  # (B, I, E)

concatenated_embeddings = torch.cat((embs, other_embs), dim=2)  # (B, I, 2*E)

attention_scores = F.softmax(self.AttentionNet(concatenated_embeddings).squeeze(-1), dim=1)    # (B, I, 1) -> (B, I), softmax over items
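A self-contained sketch of the same idea, with a loop-based reference for comparison, follows. The function names and the small attention_net are hypothetical stand-ins for the model's code. It uses expand() instead of repeat(). expand() only creates broadcast views, so the (B, I, 2*E) input is materialized once, by torch.cat:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pairwise_attention_weights(embeddings, other_embeddings, attention_net):
    """Vectorized: embeddings (B, E), other_embeddings (I, E) -> weights (B, I)."""
    B, E = embeddings.shape
    I = other_embeddings.shape[0]
    embs = embeddings.unsqueeze(1).expand(B, I, E)               # (B, I, E), view, no copy
    other_embs = other_embeddings.unsqueeze(0).expand(B, I, E)   # (B, I, E), view, no copy
    pairs = torch.cat((embs, other_embs), dim=2)                 # (B, I, 2*E)
    scores = attention_net(pairs).squeeze(-1)                    # (B, I, 1) -> (B, I)
    return F.softmax(scores, dim=1)                              # normalize over items


def loop_attention_weights(embeddings, other_embeddings, attention_net):
    """Reference version mirroring the for-loop in the question."""
    B = embeddings.shape[0]
    I = other_embeddings.shape[0]
    scores = torch.zeros((B, I), device=embeddings.device)
    for i in range(I):
        repeated = other_embeddings[i].view(1, -1).repeat(B, 1)   # (B, E)
        pairs = torch.cat((embeddings, repeated), dim=1)           # (B, 2*E)
        scores[:, [i]] = attention_net(pairs)                      # (B, 1)
    return F.softmax(scores, dim=1)


torch.manual_seed(0)
B, I, E = 4, 7, 5  # toy sizes
attention_net = nn.Sequential(nn.Linear(2 * E, 16), nn.Tanh(), nn.Linear(16, 1))
embeddings = torch.randn(B, E)
other_embeddings = torch.randn(I, E)

fast = pairwise_attention_weights(embeddings, other_embeddings, attention_net)
slow = loop_attention_weights(embeddings, other_embeddings, attention_net)
```

Both versions should agree up to floating-point rounding. The vectorized one replaces I small network calls with a single large one, which is where the speedup comes from.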

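Because the input now has shape (B, I, 2*E) instead of (B, 2*E), you may need to adapt self.AttentionNet. If it consists only of nn.Linear layers and elementwise activations, it works unchanged, since nn.Linear operates on the last dimension and accepts any number of leading dimensions. Layers that expect 2-D (N, C) input, such as nn.BatchNorm1d, require reshaping the pairs to (B*I, 2*E) first and reshaping the scores back to (B, I) afterwards.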
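If the attention network contains a layer such as nn.BatchNorm1d, which treats the second dimension of a 3-D input as channels, one option is to flatten all pairs into a single batch. A minimal sketch, assuming a hypothetical attention_net with a BatchNorm1d layer (sizes are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, I, E = 4, 6, 5  # toy sizes

# Hypothetical AttentionNet containing BatchNorm1d, which expects (N, C) input here
attention_net = nn.Sequential(
    nn.Linear(2 * E, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1)
)

pairs = torch.randn(B, I, 2 * E)              # (B, I, 2*E)
flat = pairs.reshape(B * I, 2 * E)            # every (b, i) pair becomes one sample
scores = attention_net(flat).view(B, I)       # (B*I, 1) -> (B, I)
attention_weights = F.softmax(scores, dim=1)  # (B, I)
```

One caveat: in training mode BatchNorm1d now computes its statistics over all B*I pairs. The loop computes them over the B rows of each item separately, so the two versions are not numerically identical when such layers are present.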
