Keras MultiHeadAttention layer throwing IndexError: tuple index out of range
I keep getting this error when I try to apply self-attention to 1D vectors, and I don't understand why it happens. Any help would be greatly appreciated.
import tensorflow as tf
from tensorflow.keras import layers

layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.ones(shape=[1, 16])
source = tf.ones(shape=[1, 16])
output_tensor, weights = layer(target, source)
The error:
~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/layers/multi_head_attention.py in _masked_softmax(self, attention_scores, attention_mask)
399 attention_mask = array_ops.expand_dims(
400 attention_mask, axis=mask_expansion_axes)
--> 401 return self._softmax(attention_scores, attention_mask)
402
403 def _compute_attention(self,
~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
1010 with autocast_variable.enable_auto_cast_variables(
1011 self._compute_dtype_object):
-> 1012 outputs = call_fn(inputs, *args, **kwargs)
1013
1014 if self._activity_regularizer:
~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/layers/advanced_activations.py in call(self, inputs, mask)
332 inputs, axis=self.axis, keepdims=True))
333 else:
--> 334 return K.softmax(inputs, axis=self.axis[0])
335 return K.softmax(inputs, axis=self.axis)
336
IndexError: tuple index out of range
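You are missing a dimension, which is necessary: MultiHeadAttention expects 3D inputs of shape (batch, sequence_length, features), but tf.ones(shape=[1, 16]) is only 2D. With 2D inputs there is no sequence axis for the attention softmax to normalize over, so the softmax ends up with an empty tuple of axes and `self.axis[0]` raises IndexError: tuple index out of range instead of a clearer shape error. Also, if you want both the output tensor and the corresponding attention weights, you have to set the parameter return_attention_scores to True; otherwise the layer returns only the output tensor. Try something like this: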
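import tensorflow as tf
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
samples = 5  # batch size
# Inputs must be 3D: (batch, sequence_length, features) = (5, 1, 16)
target = tf.ones(shape=[samples, 1, 16])
source = tf.ones(shape=[samples, 1, 16])
# return_attention_scores=True makes the layer return (output, scores)
output_tensor, weights = layer(target, source, return_attention_scores=True)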
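The snippet above builds 3D tensors directly. If you already have a batch of 1D vectors of shape (batch, 16), a minimal sketch (assuming TensorFlow 2.x, where tf.keras.layers.MultiHeadAttention accepts return_attention_scores) is to insert a length-1 sequence axis with tf.expand_dims and pass the same tensor as query and value for self-attention. The random input and variable names are illustrative only:

```python
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)

# Illustrative batch of 5 one-dimensional feature vectors: (batch, features) = (5, 16).
vectors = tf.random.normal(shape=[5, 16])

# Insert a sequence axis of length 1 -> (batch, seq_len, features) = (5, 1, 16).
x = tf.expand_dims(vectors, axis=1)

# Self-attention: the same tensor is used as query and value.
attn_output, attn_scores = layer(x, x, return_attention_scores=True)

print(attn_output.shape)  # (5, 1, 16): same shape as the query
print(attn_scores.shape)  # (5, 2, 1, 1): (batch, num_heads, query_len, key_len)
```

With a sequence length of 1, each query has exactly one key to attend to, so every attention weight is 1.0 and the layer effectively reduces to its learned projections. Attention only becomes meaningful when the sequence axis has more than one position.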
Also, according to the docs:
query: Query Tensor of shape (B, T, dim)
value: Value Tensor of shape (B, S, dim)
Here B is the batch size, T and S are the target and source sequence lengths, and dim is the feature size, so both inputs must be 3D.
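If the goal is attention between the elements of each vector rather than across whole vectors, one possible reading (an assumption about the intended use, not something the answer proposes) is to treat each of the 16 values as its own token, so T = S = 16 and dim = 1. A minimal sketch, assuming TensorFlow 2.x and illustrative random data:

```python
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)

batch, features = 5, 16
vectors = tf.random.normal(shape=[batch, features])  # (B, features)

# Treat every scalar as a token: (B, T=16, dim=1).
tokens = tf.reshape(vectors, [batch, features, 1])

# Self-attention over the 16 positions of each vector.
token_out, token_scores = layer(tokens, tokens, return_attention_scores=True)

print(token_out.shape)     # (5, 16, 1): last dim matches the query's dim
print(token_scores.shape)  # (5, 2, 16, 16): (B, num_heads, T, S)
```

Whether one token per vector or one token per element is appropriate depends on the model; the (B, T, dim) shape rule applies either way, and each row of the attention scores sums to 1 over the S axis.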