
Custom gradient with complex exponential in tensorflow

As an exercise I am trying to build a custom operator in TensorFlow and check its gradient against TensorFlow's autodiff of the same forward operation composed of TensorFlow API operations. However, the gradient of my custom operator is incorrect. It seems my complex analysis is not correct and needs some brushing up.

import tensorflow as tf

shape = (1, 16)
dtype = tf.complex64

x = tf.cast(tf.complex(tf.random.normal(shape), tf.random.normal(shape)), dtype=dtype)

def fun(x):
    phi = x * tf.math.conj(x)
    e = tf.exp(1j * phi)
    return e

def d_fun(x):
    d_phi = x + tf.math.conj(x)
    phi = x * tf.math.conj(x)
    d_e = 1j * d_phi * tf.exp(1j * phi)
    return d_e

@tf.custom_gradient
def tf_custom(x):    
    e = fun(x)
    def grad(dy):
        d_e = d_fun(x)
        return dy * d_e
    return e, grad

with tf.GradientTape() as g:
    g.watch(x)
    res = fun(x)
    
dy_dx = g.gradient(res, x)

with tf.GradientTape() as g:
    g.watch(x)
    res2 = tf_custom(x)
    
dy_dx2 = g.gradient(res2, x)

print(tf.reduce_sum(tf.abs(res - res2)).numpy())
print(tf.reduce_sum(tf.abs(dy_dx - dy_dx2)).numpy())

TensorFlow 2 does not directly compute the complex derivative df/dz. Instead, it treats a complex function as a function of the real and imaginary parts of its input (equivalently, of the two independent variables x and conj(x)) and backpropagates using Wirtinger calculus. Concretely, TensorFlow's registered gradients (for example those of multiplication and tf.math.conj) follow the convention

    dx = dy * conj(df/dx) + conj(dy) * df/dconj(x)

where df/dx and df/dconj(x) are the Wirtinger derivatives. This reduces to dy * conj(f'(x)) only when f is holomorphic. Your fun depends on conj(x), so it is not holomorphic, and the single-term chain rule in d_fun cannot match what autodiff computes.
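Under that convention, the custom gradient for f(x) = exp(1j * x * conj(x)) can be sketched as follows. The Wirtinger derivatives are df/dx = 1j * conj(x) * f and df/dconj(x) = 1j * x * f; the backprop rule combining them is my reading of TensorFlow's convention, not something the question's code already contains:

```python
import tensorflow as tf

tf.random.set_seed(0)
shape = (1, 16)
x = tf.complex(tf.random.normal(shape), tf.random.normal(shape))

def fun(x):
    # f(x) = exp(1j * x * conj(x)); depends on conj(x), so not holomorphic
    return tf.exp(1j * (x * tf.math.conj(x)))

@tf.custom_gradient
def tf_custom(x):
    e = fun(x)
    def grad(dy):
        # Wirtinger derivatives of f(x) = exp(1j * x * conj(x)):
        #   df/dx        = 1j * conj(x) * f
        #   df/dconj(x)  = 1j * x * f
        # Assumed TF backprop convention:
        #   dx = dy * conj(df/dx) + conj(dy) * df/dconj(x)
        df_dx = 1j * tf.math.conj(x) * e
        df_dxbar = 1j * x * e
        return dy * tf.math.conj(df_dx) + tf.math.conj(dy) * df_dxbar
    return e, grad

with tf.GradientTape() as g:
    g.watch(x)
    res = fun(x)
dy_dx = g.gradient(res, x)

with tf.GradientTape() as g:
    g.watch(x)
    res2 = tf_custom(x)
dy_dx2 = g.gradient(res2, x)

print(tf.reduce_sum(tf.abs(res - res2)).numpy())
print(tf.reduce_sum(tf.abs(dy_dx - dy_dx2)).numpy())  # close to 0 up to float rounding
```

With this two-term gradient, the custom operator reproduces autodiff's result; the original d_fun only amounts to the df/dx term of the chain rule and drops the conj(x) dependence.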
