
TensorFlow MultiHeadAttention on inputs of shape 4 x 5 x 20 x 64 with attention_axes=2 throws a mask dimension error (tf 2.11.0)

The expectation is that attention is applied along axis 2 of the embedded (4, 5, 20, 64) input, i.e. over the dimension of size 20. I am trying to apply self-attention using the following code (the issue is reproducible with this code):

import numpy as np
import tensorflow as tf
from keras import layers as tfl

class Encoder(tfl.Layer):
    def __init__(self):
        super().__init__()
        self.embed_layer = tfl.Embedding(4500, 64, mask_zero=True)
        self.attn_layer = tfl.MultiHeadAttention(num_heads=2,
                                                 attention_axes=2,
                                                 key_dim=16)

    def call(self, x):
        # Input shape: (4, 5, 20) (batch size: 4)
        x = self.embed_layer(x)  # Output: (4, 5, 20, 64)
        x = self.attn_layer(query=x, key=x, value=x)  # Expected output: (4, 5, 20, 64); raises the error below instead
        return x


eg_input = tf.constant(np.random.randint(0, 150, (4, 5, 20)))
enc = Encoder()
enc(eg_input)
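For reference, this is what attention along axis 2 means for a (4, 5, 20, 64) input: axes 0 and 1 act as batch axes, and each of the 4 x 5 groups of 20 vectors attends only within itself. The NumPy sketch below is a simplification (one head, no learned query/key/value projections, no mask), not the Keras implementation; the real layer adds per-head projections (num_heads=2, key_dim=16), but its scores have the same 20 x 20 structure per group and head.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 20, 64))  # (batch, groups, seq, features)

# Scaled dot-product scores between every pair of the 20 positions,
# computed independently for each (batch, group) pair.
scores = np.einsum("bgqd,bgkd->bgqk", x, x) / np.sqrt(x.shape[-1])  # (4, 5, 20, 20)
weights = softmax(scores, axis=-1)  # normalise over the key axis
out = np.einsum("bgqk,bgkd->bgqd", weights, x)  # (4, 5, 20, 64)

# Same result as plain 2-D self-attention run on one group at a time
# (sqrt(64) == 8).
g = x[1, 3]  # (20, 64)
print(np.allclose(out[1, 3], softmax(g @ g.T / 8.0) @ g))  # True
```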

However, the layer defined above throws the following error. Could someone please explain why this happens and how to fix it?

{{function_node __wrapped__AddV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [4,5,2,20,20] vs. [4,5,1,5,20] [Op:AddV2]

Call arguments received by layer 'softmax_2' (type Softmax):
  • inputs=tf.Tensor(shape=(4, 5, 2, 20, 20), dtype=float32)
  • mask=tf.Tensor(shape=(4, 5, 1, 5, 20), dtype=bool)

PS: If I set mask_zero=False when defining the embedding layer, the code runs as expected without any issues.
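Why the mask breaks: with mask_zero=True, the Embedding attaches a boolean padding mask of shape (4, 5, 20) to its output, and MultiHeadAttention builds an attention mask from it automatically. In the tf.keras 2.11 source that automatic mask is built as `query_mask[:, :, None] & value_mask[:, None, :]`, a rule written for rank-2 (batch, seq) masks, and a heads axis is then inserted at position -3 before the softmax. The NumPy sketch below replays only that shape arithmetic (assuming the rule just described; it is private Keras code and may differ between versions). It reproduces the (4, 5, 1, 5, 20) mask from the error message, whereas attention over axis 2 needs a (4, 5, 20, 20) mask. With mask_zero=False there is no mask, so nothing is built and nothing clashes.

```python
import numpy as np

# Padding masks as attached by Embedding(mask_zero=True) to (4, 5, 20) token ids:
# one boolean per token, shape (batch, groups, seq).
query_mask = np.ones((4, 5, 20), dtype=bool)
value_mask = np.ones((4, 5, 20), dtype=bool)

# Automatic-mask rule written for rank-2 masks: [B, T, 1] & [B, 1, S] -> [B, T, S].
# On rank-3 masks it crosses axes 1 and 2 instead of the 20 x 20 token pairs.
auto_mask = query_mask[:, :, None] & value_mask[:, None, :]
print(auto_mask.shape)  # (4, 5, 5, 20)

# A heads axis is inserted at position -3 before the softmax.
mask_for_softmax = np.expand_dims(auto_mask, axis=-3)
print(mask_for_softmax.shape)  # (4, 5, 1, 5, 20)

scores_shape = (4, 5, 2, 20, 20)  # (batch, groups, heads, query, key)
try:
    np.broadcast_shapes(scores_shape, mask_for_softmax.shape)
except ValueError as err:
    print("incompatible shapes:", err)

# What attention over axis 2 actually needs: a (batch, groups, seq, seq) mask.
correct = query_mask[..., :, None] & value_mask[..., None, :]
print(np.broadcast_shapes(scores_shape, np.expand_dims(correct, -3).shape))  # (4, 5, 2, 20, 20)
```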

Just concatenate the input along axis=0:

import numpy as np
import tensorflow as tf
from keras import layers as tfl

class Encoder(tfl.Layer):
    def __init__(self):
        super().__init__()
        self.embed_layer = tfl.Embedding(4500, 64, mask_zero=True)
        self.attn_layer = tfl.MultiHeadAttention(num_heads=2,
                                                 key_dim=16,
                                                 attention_axes=2)

    def call(self, x):
        x = self.embed_layer(x)  # Output: (4, 5, 20, 64)
        # tf.concat of a single tensor is an identity copy; the copy has no Keras mask
        x = tf.concat(x, axis=0)
        x, attention_scores = self.attn_layer(query=x, key=x, value=x, return_attention_scores=True)  # Output: (4, 5, 20, 64)
        return x, attention_scores


eg_input = tf.constant(np.random.randint(0, 150, (4, 5, 20)))
enc = Encoder()
outputs, scores = enc(eg_input)
outputs.shape, scores.shape
#(TensorShape([4, 5, 20, 64]), TensorShape([4, 5, 2, 20, 20]))
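A caveat on this answer: `tf.concat` called on a single tensor simply returns `tf.identity` of it, and the new tensor does not carry the `_keras_mask` that `Embedding(mask_zero=True)` attached. MultiHeadAttention therefore receives no mask at all, which is why the error disappears; the result is equivalent to mask_zero=False, and padded positions are attended to. If padding should still be masked, one option is to build the (batch, 5, 20, 20) mask yourself and pass it through the layer's `attention_mask` argument. The sketch below assumes token id 0 is padding and turns mask_zero off, because with it on the automatic (4, 5, 5, 20) mask would still be combined with the explicit one and fail. `MaskedEncoder` is a hypothetical name; this is one possible workaround, not an official fix.

```python
import numpy as np
import tensorflow as tf


class MaskedEncoder(tf.keras.layers.Layer):
    """Self-attention along axis 2 with an explicitly built padding mask (sketch)."""

    def __init__(self, vocab_size=4500, embed_dim=64, num_heads=2, key_dim=16):
        super().__init__()
        # mask_zero=False: no Keras mask is attached to the embedding output,
        # so MultiHeadAttention does not derive its own (wrongly shaped) mask.
        self.embed_layer = tf.keras.layers.Embedding(vocab_size, embed_dim, mask_zero=False)
        self.attn_layer = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, attention_axes=2)

    def call(self, tokens):
        # tokens: (batch, 5, 20) integer ids; id 0 is assumed to be padding.
        pad_mask = tf.not_equal(tokens, 0)  # (batch, 5, 20)
        # Query i may attend to key j only if both are real tokens: (batch, 5, 20, 20).
        attention_mask = pad_mask[..., :, None] & pad_mask[..., None, :]
        x = self.embed_layer(tokens)  # (batch, 5, 20, 64)
        return self.attn_layer(query=x, value=x, key=x,
                               attention_mask=attention_mask,
                               return_attention_scores=True)


rng = np.random.default_rng(0)
tokens = rng.integers(1, 150, size=(4, 5, 20))  # real token ids 1..149
tokens[:, :, 15:] = 0  # last 5 positions of every group are padding

enc = MaskedEncoder()
outputs, scores = enc(tf.constant(tokens))
print(outputs.shape, scores.shape)  # (4, 5, 20, 64) (4, 5, 2, 20, 20)
```

Fully padded query rows still receive a (meaningless) softmax over the masked keys; downstream layers would normally ignore those positions.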
