[英]Higher-order multivariate derivatives in jax
我对如何在 jax 中计算高阶多元导数感到困惑。
例如,你如何计算 d^2f / dx dy
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
其中 x, y 在 R^n, n >= 1 中?
我一直在尝试jax.jvp
和jax.partial
,但没有任何成功。
由于x
和y
是向量值并且f(x, y)
是标量,我相信您可以通过将jax.jacfwd
和jax.jacrev
函数与适当的 argnums 组合来计算您所追求的:
import jax.numpy as jnp
from jax import jacfwd, jacrev
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
x = jnp.arange(4.0)
y = jnp.ones(4)
print(d2f_dxdy(x, y))
# DeviceArray([[0.96017027, 0. , 0. , 0. ],
# [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
# [0.558831 , 0.558831 , 1.5190012 , 0.558831 ],
# [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
# dtype=float32)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.