
How to define a custom tensor operation in TensorFlow, similar to tf.matmul?

I was working on self-attention and came across the paper "Quantifying Attention Flow in Transformers": https://arxiv.org/pdf/2005.00928.pdf

I am trying to compute attention flow as suggested in the paper and the accompanying blog post: https://samiraabnar.github.io/articles/2020-04/attention_flow

One of the authors has a GitHub repo, https://github.com/samiraabnar/attention_flow, which uses networkx to compute attention flow. However, it is very slow when dealing with long sequences.

Long story short, I would like to use TensorFlow and GPU acceleration to speed up the computation. However, there is no straightforward tf operation to do so.

In particular, computing the maximum flow requires something similar to matrix multiplication.

I would like a tensor operation similar to y = tf.matmul(x1, x2), where tf.matmul returns a tensor y such that y[...,i,j] = sum(x1[...,i,:] * x2[...,:,j]).

However, instead of the dot product, I would like to define a new operation that replaces the element-wise multiplication with an element-wise maximum, i.e. y = some_tf_op(x1, x2), where y[...,i,j] = sum(tf.maximum(x1[...,i,:], x2[...,:,j])).

I understand that this is doable outside of graph computation, but I wish to place it inside a graph computation (e.g. the call method of a tf.keras.layers.Layer or a tf.keras.Model) without consuming too many resources.

i.e.:

import tensorflow as tf
# A is a tensor with shape (...,m,n), 
A = tf.constant([[1,2,3],[4,5,6]])

display('A',A)


# B is a tensor with shape (...,n,m)
B = tf.constant([[6,3],[5,2],[4,1]])
display('B',B)

@tf.function
def some_tf_op(x1,x2):
    ...
    return output
C = some_tf_op(A, B)
display('C',C)

expected output:

A
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)>
B
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[6, 3],
       [5, 2],
       [4, 1]], dtype=int32)>
C
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[15,  8],
       [17, 15]], dtype=int32)>
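
As a quick sanity check, C[0,0] = max(1,6) + max(2,5) + max(3,4) = 6 + 5 + 4 = 15, and C[1,0] = max(4,6) + max(5,5) + max(6,4) = 6 + 5 + 6 = 17, matching the expected output above.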

I have worked out a solution, but it seems very expensive, especially when dealing with large tensors.

import tensorflow as tf
# A is a tensor with shape (...,m,n), 
A = tf.reshape(tf.range(3*4*5), (3,4,5))
display('A',A)


# B is a tensor with shape (...,n,p)
B = tf.transpose(tf.reshape(tf.range(3*6*5,0,-1), (3,6,5)), (0,2,1))
display('B',B)

@tf.function
def some_tf_op(x1, x2):
    A = x1                             # (..., m, n)
    B = tf.einsum('...ij->...ji', x2)  # transpose last two axes: (..., p, n)

    A = tf.expand_dims(A, axis=-2)                       # (..., m, 1, n)
    A = tf.repeat(A, axis=-2, repeats=tf.shape(B)[-2])   # (..., m, p, n)

    B = tf.expand_dims(B, axis=-3)                       # (..., 1, p, n)
    B = tf.repeat(B, axis=-3, repeats=tf.shape(A)[-3])   # (..., m, p, n)

    # element-wise maximum, then sum over the shared axis n
    C = tf.reduce_sum(tf.math.maximum(A, B), axis=-1)    # (..., m, p)
    return C
C = some_tf_op(A, B)
# C is a tensor with shape (..., m, p)
display('C',C)

output

A
<tf.Tensor: shape=(3, 4, 5), dtype=int32, numpy=
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]], dtype=int32)>
B
<tf.Tensor: shape=(3, 5, 6), dtype=int32, numpy=
array([[[90, 85, 80, 75, 70, 65],
        [89, 84, 79, 74, 69, 64],
        [88, 83, 78, 73, 68, 63],
        [87, 82, 77, 72, 67, 62],
        [86, 81, 76, 71, 66, 61]],

       [[60, 55, 50, 45, 40, 35],
        [59, 54, 49, 44, 39, 34],
        [58, 53, 48, 43, 38, 33],
        [57, 52, 47, 42, 37, 32],
        [56, 51, 46, 41, 36, 31]],

       [[30, 25, 20, 15, 10,  5],
        [29, 24, 19, 14,  9,  4],
        [28, 23, 18, 13,  8,  3],
        [27, 22, 17, 12,  7,  2],
        [26, 21, 16, 11,  6,  1]]], dtype=int32)>
C
<tf.Tensor: shape=(3, 4, 6), dtype=int32, numpy=
array([[[440, 415, 390, 365, 340, 315],
        [440, 415, 390, 365, 340, 315],
        [440, 415, 390, 365, 340, 315],
        [440, 415, 390, 365, 340, 315]],

       [[290, 265, 240, 215, 190, 165],
        [290, 265, 240, 215, 190, 165],
        [290, 265, 240, 215, 190, 169],
        [290, 265, 240, 215, 194, 185]],

       [[210, 210, 210, 210, 210, 210],
        [235, 235, 235, 235, 235, 235],
        [260, 260, 260, 260, 260, 260],
        [285, 285, 285, 285, 285, 285]]], dtype=int32)>
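
For reference, the same result can be obtained with broadcasting instead of tf.repeat, since tf.maximum broadcasts its operands; this skips the explicit tiling, although the reduction still materializes the full (..., m, p, n) intermediate. A sketch of what I mean (max_sum_op is just a name I made up, not a TensorFlow API):

import tensorflow as tf

@tf.function
def max_sum_op(x1, x2):
    # x1: (..., m, n), x2: (..., n, p)
    a = tf.expand_dims(x1, axis=-2)                              # (..., m, 1, n)
    b = tf.expand_dims(tf.linalg.matrix_transpose(x2), axis=-3)  # (..., 1, p, n)
    # tf.maximum broadcasts both operands to (..., m, p, n)
    return tf.reduce_sum(tf.maximum(a, b), axis=-1)              # (..., m, p)

Since this is built from ordinary graph ops, it can be called from the call method of a tf.keras.layers.Layer or a tf.keras.Model as-is.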
