簡體   English   中英

無法在 JAX 中計算 lambda function 的簡單梯度

[英]Cannot compute simple gradient of lambda function in JAX

我正在嘗試計算涉及其他函數梯度的 lambda function 的梯度,但計算掛起,我不明白為什么。 特別是,下面的代碼成功計算f_next ,但不是它的導數(倒數第二行和最后一行)。 任何幫助,將不勝感激

import jax
import jax.numpy as jnp

# Model parameters
γ = 1.5
k = 0.1
μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ

# PDE params.
σω = σ

dt =0.01

IC = lambda ω: jnp.exp(-(1-γ)*ω)

f  = [IC]

f_x= jax.grad(f[0]) #first derivative
f_xx= jax.grad(jax.grad(f[0]))#second derivative
f_old = f[0]
f_next = lambda ω: f_old(ω) + 100*dt * (
             (0.5*σω**2)*f_xx(ω) - λ*(ω-ωb)*f_x(ω) 
                - k*f_old(ω) + jnp.exp(-(1-γ)*ω))
print(f_next(0.))
f.append(f_next)

f_x= jax.grad(f[1]) #first derivative
print(f_x(0.))

這是因為您試圖在倒數第二行中使用 f_x 定義 f_x,因此您試圖無限期地計算梯度。 如果您通過以下方式更改它:

new_f_x = jax.grad(f[1])

它會起作用。

順便說一句,即使在您的情況下 model 參數是常量,您的函數也有副作用(不純),不應以這種形式對它們進行分級。 相反,您應該像這樣在函數中添加參數:

# Model parameters
params = {'γ': 1.5,
          'k': 0.1,
          'μY': 0.03,
          'σ': 0.03,
          'λ': 0.1,
          'ωb': 0.03 / 0.1}

IC = lambda ω, params: jnp.exp(-(1-params['γ']) * ω)


def f_next(ω, params):
    γ = params['γ']
    k = params['k']
    σ = params['σ']
    λ = params['λ']
    ωb = params['ωb']

    # PDE params.
    σω = σ
    dt = 0.01

    f_x = jax.grad(IC)
    f_xx = jax.grad(jax.grad(IC))
    return f_old(ω, params) + 100*dt * (
        (0.5 * σω**2) * f_xx(ω, params) - λ * (ω-ωb) * f_x(ω, params)
        - k * f_old(ω, params) + jnp.exp(-(1-γ) * ω)
        )

f = [IC]
f_old = f[0]

print(f_next(0., params))
f.append(f_next)

new_f_x = jax.grad(f[1])
print(new_f_x(0., params))

現在您可以使用具有相同功能的其他參數計算正確的梯度。 如果需要,您甚至可以更改f_next中的參數。 請注意,使用參數字典作為 function 輸入在 Jax 中非常經典。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM