简体   繁体   English

jax中的高阶多元导数

[英]Higher-order multivariate derivatives in jax

I am confused about how to compute higher-order multivariate derivatives in jax.我对如何在 jax 中计算高阶多元导数感到困惑。

For example, how do you compute d^2f / dx dy for例如,你如何计算 d^2f / dx dy

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

where x, y in R^n, n >= 1?其中 x, y 在 R^n, n >= 1 中?

I've been experimenting with jax.jvp and jax.partial , but I haven't had any success.我一直在尝试jax.jvpjax.partial ,但没有任何成功。

Since x and y are vector-valued and f(x, y) is a scalar, I believe you can compute what you're after by combining the jax.jacfwd and jax.jacrev functions with appropriate argnums:由于xy是向量值并且f(x, y)是标量,我相信您可以通过将jax.jacfwdjax.jacrev函数与适当的 argnums 组合来计算您所追求的:

import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
  
x = jnp.arange(4.0)
y = jnp.ones(4)

print(d2f_dxdy(x, y))

# DeviceArray([[0.96017027, 0.        , 0.        , 0.        ],
#              [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
#              [0.558831  , 0.558831  , 1.5190012 , 0.558831  ],
#              [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
#             dtype=float32)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM