
Cannot compute simple gradient of lambda function in JAX

I'm trying to compute the gradient of a lambda function that itself involves gradients of other functions, but the computation hangs and I do not understand why. In particular, the code below successfully computes f_next, but not its derivative (penultimate and last line). Any help would be appreciated.

import jax
import jax.numpy as jnp

# Model parameters
γ = 1.5
k = 0.1
μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ

# PDE params.
σω = σ

dt = 0.01

IC = lambda ω: jnp.exp(-(1-γ)*ω)

f  = [IC]

f_x = jax.grad(f[0])             # first derivative
f_xx = jax.grad(jax.grad(f[0]))  # second derivative
f_old = f[0]
f_next = lambda ω: f_old(ω) + 100*dt * (
             (0.5*σω**2)*f_xx(ω) - λ*(ω-ωb)*f_x(ω) 
                - k*f_old(ω) + jnp.exp(-(1-γ)*ω))
print(f_next(0.))
f.append(f_next)

f_x = jax.grad(f[1])  # first derivative
print(f_x(0.))

It is because in the penultimate line you rebind f_x to jax.grad(f[1]), but f[1] (your f_next) itself calls f_x, and Python looks that name up at call time. Evaluating the new f_x therefore calls f_next, which calls f_x again, and so on, so the gradient computation recurses indefinitely. If you change it to:

new_f_x = jax.grad(f[1])

it will work.
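Here is the same late-binding trap in a minimal sketch (g and h are hypothetical stand-ins for f[0] and f_next, not part of your code):

import jax

g = lambda x: x ** 2      # hypothetical placeholder, plays the role of f[0]
h = lambda x: g(x) + 1.0  # plays the role of f_next; it looks up the name g at call time
h_x = jax.grad(h)         # binding the gradient to a *new* name works fine
print(h_x(1.0))           # 2.0
g = jax.grad(h)           # rebinding g instead makes h call itself through g
# g(1.0)                  # would now recurse until Python's recursion limit is hit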

By the way, even though in your case the model parameters are constants, your functions read global state (they are impure), and you should not take their gradients in this form. Instead, pass the parameters to your functions explicitly, like this:

# Model parameters
params = {'γ': 1.5,
          'k': 0.1,
          'μY': 0.03,
          'σ': 0.03,
          'λ': 0.1,
          'ωb': 0.03 / 0.1}

IC = lambda ω, params: jnp.exp(-(1-params['γ']) * ω)


def f_next(ω, params):
    # Unpack model parameters
    γ = params['γ']
    k = params['k']
    σ = params['σ']
    λ = params['λ']
    ωb = params['ωb']

    # PDE params.
    σω = σ
    dt = 0.01

    # Derivatives of the previous iterate (the initial condition) w.r.t. ω
    f_x = jax.grad(IC)
    f_xx = jax.grad(jax.grad(IC))
    return f_old(ω, params) + 100*dt * (
        (0.5 * σω**2) * f_xx(ω, params) - λ * (ω-ωb) * f_x(ω, params)
        - k * f_old(ω, params) + jnp.exp(-(1-γ) * ω)
        )

f = [IC]
f_old = f[0]

print(f_next(0., params))
f.append(f_next)

new_f_x = jax.grad(f[1])
print(new_f_x(0., params))

Now you can compute the correct gradients with different parameters using the same functions. You can even change the parameters inside f_next if needed. Note that passing a dictionary of parameters as a function input is a very common pattern in JAX.
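For instance, here is a small usage sketch with the code above (new_params and the values 2.0 and 0.2 are illustrative, not from the post):

new_params = dict(params, γ=2.0, k=0.2)  # copy of params with a few entries overridden
print(f_next(0., new_params))            # f_next at ω = 0 with the new parameters
print(new_f_x(0., new_params))           # its derivative w.r.t. ω (grad differentiates argument 0 by default)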
