簡體   English   中英

jax中的高階多元導數

[英]Higher-order multivariate derivatives in jax

我對如何在 jax 中計算高階多元導數感到困惑。

例如,你如何計算 d^2f / dx dy

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

其中 x, y 在 R^n, n >= 1 中?

我一直在嘗試jax.jvpjax.partial ,但沒有任何成功。

由於xy是向量值並且f(x, y)是標量,我相信您可以通過將jax.jacfwdjax.jacrev函數與適當的 argnums 組合來計算您所追求的:

import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
  
x = jnp.arange(4.0)
y = jnp.ones(4)

print(d2f_dxdy(x, y))

# DeviceArray([[0.96017027, 0.        , 0.        , 0.        ],
#              [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
#              [0.558831  , 0.558831  , 1.5190012 , 0.558831  ],
#              [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
#             dtype=float32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM