[英]Higher-order multivariate derivatives in jax
我對如何在 jax 中計算高階多元導數感到困惑。
例如,你如何計算 d^2f / dx dy
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
其中 x, y 在 R^n, n >= 1 中?
我一直在嘗試jax.jvp
和jax.partial
,但沒有任何成功。
由於x
和y
是向量值並且f(x, y)
是標量,我相信您可以通過將jax.jacfwd
和jax.jacrev
函數與適當的 argnums 組合來計算您所追求的:
import jax.numpy as jnp
from jax import jacfwd, jacrev
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
x = jnp.arange(4.0)
y = jnp.ones(4)
print(d2f_dxdy(x, y))
# DeviceArray([[0.96017027, 0. , 0. , 0. ],
# [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
# [0.558831 , 0.558831 , 1.5190012 , 0.558831 ],
# [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
# dtype=float32)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.