[英]Cannot compute simple gradient of lambda function in JAX
我正在嘗試計算涉及其他函數梯度的 lambda function 的梯度,但計算掛起,我不明白為什么。 特別是,下面的代碼成功計算f_next
,但不是它的導數(倒數第二行和最后一行)。 任何幫助,將不勝感激
import jax
import jax.numpy as jnp
# Model parameters
γ = 1.5
k = 0.1
μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ
# PDE params.
σω = σ
dt =0.01
IC = lambda ω: jnp.exp(-(1-γ)*ω)
f = [IC]
f_x= jax.grad(f[0]) #first derivative
f_xx= jax.grad(jax.grad(f[0]))#second derivative
f_old = f[0]
f_next = lambda ω: f_old(ω) + 100*dt * (
(0.5*σω**2)*f_xx(ω) - λ*(ω-ωb)*f_x(ω)
- k*f_old(ω) + jnp.exp(-(1-γ)*ω))
print(f_next(0.))
f.append(f_next)
f_x= jax.grad(f[1]) #first derivative
print(f_x(0.))
這是因為您試圖在倒數第二行中使用 f_x 定義 f_x,因此您試圖無限期地計算梯度。 如果您通過以下方式更改它:
new_f_x = jax.grad(f[1])
它會起作用。
順便說一句,即使在您的情況下 model 參數是常量,您的函數也有副作用(不純),不應以這種形式對它們進行分級。 相反,您應該像這樣在函數中添加參數:
# Model parameters
params = {'γ': 1.5,
'k': 0.1,
'μY': 0.03,
'σ': 0.03,
'λ': 0.1,
'ωb': 0.03 / 0.1}
IC = lambda ω, params: jnp.exp(-(1-params['γ']) * ω)
def f_next(ω, params):
γ = params['γ']
k = params['k']
σ = params['σ']
λ = params['λ']
ωb = params['ωb']
# PDE params.
σω = σ
dt = 0.01
f_x = jax.grad(IC)
f_xx = jax.grad(jax.grad(IC))
return f_old(ω, params) + 100*dt * (
(0.5 * σω**2) * f_xx(ω, params) - λ * (ω-ωb) * f_x(ω, params)
- k * f_old(ω, params) + jnp.exp(-(1-γ) * ω)
)
f = [IC]
f_old = f[0]
print(f_next(0., params))
f.append(f_next)
new_f_x = jax.grad(f[1])
print(new_f_x(0., params))
現在您可以使用具有相同功能的其他參數計算正確的梯度。 如果需要,您甚至可以更改f_next
中的參數。 請注意,使用參數字典作為 function 輸入在 Jax 中非常經典。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.