
Find gradient of a function: Sympy vs. Jax

I have a function Black_Cox() which calls other functions as shown below:

import numpy as np
from scipy import stats

# Parameters
D = 100
r = 0.05
γ = 0.1

# Normal CDF
N = lambda x: stats.norm.cdf(x)

H = lambda V, T, L, σ: np.exp(-r*T) * N( (np.log(V/L) + (r-0.5*σ**2)*T) / (σ*np.sqrt(T)) )

# Black-Scholes
def C_BS(V, K, T, σ):
    d1 = (np.log(V/K) + (r + 0.5*σ**2)*T ) / ( σ*np.sqrt(T) )
    d2 = d1 - σ*np.sqrt(T)
    return V*N(d1) - np.exp(-r*T)*K*N(d2)

def BL(V, T, D, L, σ):
    return L * H(V, T, L, σ) - L * (L/V)**(2*r/σ**2-1) * H(L**2/V, T, L, σ) \
              + C_BS(V, L, T, σ) - (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, L, T, σ) \
              - C_BS(V, D, T, σ) + (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, D, T, σ)

def Bb(V, T, C, γ, σ, a):
    b = (np.log(C/V) - γ*T) / σ
    μ = (r - a - 0.5*σ**2 - γ) / σ
    m = np.sqrt(μ**2 + 2*r)
    return C*np.exp(b*(μ-m)) * ( N((b-m*T)/np.sqrt(T)) + np.exp(2*m*b)*N((b+m*T)/np.sqrt(T)) )

def Black_Cox(V, T, C=160, σ=0.1, a=0):
    return np.exp(γ*T)*BL(V*np.exp(-γ*T), T, D*np.exp(-γ*T), C*np.exp(-γ*T), σ) + Bb(V, T, C, γ, σ, a)

I need to work with the derivative of the Black_Cox function with respect to V. More precisely, I need to evaluate this derivative across thousands of paths, where on each path I change the other arguments, find the derivative, and evaluate it at some V.

What is the best way to proceed?

  • Should I use sympy to find this derivative and then evaluate it at my V of choice, as I would do in Mathematica: D[BlackCox[V, 10, 100, 160], V] /. V -> 180, or

  • Should I just use jax?

If sympy, how would you advise me to do this?
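
For what it's worth, my understanding is that the sympy route would mean rebuilding the formulas symbolically, e.g. expressing the normal CDF through sp.erf, since scipy's stats.norm.cdf is not symbolic. A minimal sketch for the C_BS piece alone (the names N_sym and delta are just illustrative):

import sympy as sp

r = 0.05
V, K, T, σ = sp.symbols('V K T σ', positive=True)

# Standard normal CDF in closed form via the error function
N_sym = lambda x: (1 + sp.erf(x / sp.sqrt(2))) / 2

d1 = (sp.log(V/K) + (r + σ**2/2)*T) / (σ*sp.sqrt(T))
d2 = d1 - σ*sp.sqrt(T)
C = V*N_sym(d1) - sp.exp(-r*T)*K*N_sym(d2)

# Symbolic derivative w.r.t. V, then compile to a fast numeric function
dC_dV = sp.diff(C, V)
delta = sp.lambdify((V, K, T, σ), dC_dV, modules=["scipy", "numpy"])
delta(180.0, 100.0, 10.0, 0.1)

Doing this for the whole Black_Cox chain would mean rewriting every helper symbolically in the same way.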

With jax I understand that I need to do the following imports:

import jax.numpy as np
from jax.scipy import stats
from jax import grad

and re-evaluate my function definitions before taking the gradient:

func = lambda x: Black_Cox(x, 10, 160, 0.1)
grad(func)(180.0)
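
Since I need the derivative across thousands of paths, I assume I could then batch the gradient with jax.vmap (a sketch, assuming the function definitions have been re-executed against the jax imports above):

from jax import grad, vmap
import jax.numpy as jnp

dBC = grad(Black_Cox, argnums=0)  # derivative w.r.t. V

# Vectorize over per-path values of V and T; C, σ, a stay fixed
dBC_paths = vmap(dBC, in_axes=(0, 0, None, None, None))

Vs = jnp.linspace(150.0, 200.0, 1000)   # one V per path
Ts = jnp.full_like(Vs, 10.0)
dBC_paths(Vs, Ts, 160.0, 0.1, 0.0)      # one derivative per path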

If I still need to work with the numpy version of the functions, will I have to keep two instances of each function, or is there an elegant way to duplicate a function for jax purposes?

Jax does not provide a built-in way to recompile a numpy function using the jax versions of numpy and scipy, but you can use a snippet like the following to do it automatically:

import inspect
from functools import wraps
import numpy as np
import jax.numpy

def replace_globals(func, globals_):
  """Recompile a function with replaced global values."""
  namespace = func.__globals__.copy()            # start from the function's own globals
  namespace.update(globals_)                     # swap in the replacement modules/values
  source = inspect.getsource(func)               # requires the function's source to be available
  exec(source, namespace)                        # re-define the function in the new namespace
  return wraps(func)(namespace[func.__name__])   # note: fails for lambdas, whose __name__ is '<lambda>'

It works like this:

def numpy_func(N):
  return np.arange(N) ** 2

jax_func = replace_globals(numpy_func, {"np": jax.numpy})

Now you can evaluate the numpy version:

numpy_func(10)
# array([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81])

and the jax version:

jax_func(10)
# DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)
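
Note the dtype in the jax output: jax defaults to 32-bit values, unlike numpy's 64-bit defaults. If you need double precision for the derivatives, enable it before running any jax computation:

import jax
jax.config.update("jax_enable_x64", True)  # make jax use float64/int64 by default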

Just make certain you replace all the relevant global variables when you wrap your more complicated function.
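
For the Black_Cox chain in the question, this means recompiling the helpers bottom-up, since each function finds the next one through its globals. One caveat: the lambdas N and H cannot go through replace_globals (see the comment in the snippet above), so they have to be redefined by hand. A sketch, assuming the definitions from the question are in scope:

import jax.numpy as jnp
from jax.scipy import stats as jstats
from jax import grad

# Lambdas must be redefined directly against the jax modules
N_jax = lambda x: jstats.norm.cdf(x)
H_jax = lambda V, T, L, σ: jnp.exp(-r*T) * N_jax((jnp.log(V/L) + (r - 0.5*σ**2)*T) / (σ*jnp.sqrt(T)))

# def-style functions can be recompiled automatically, bottom-up
C_BS_jax = replace_globals(C_BS, {"np": jnp, "N": N_jax})
BL_jax = replace_globals(BL, {"np": jnp, "H": H_jax, "C_BS": C_BS_jax})
Bb_jax = replace_globals(Bb, {"np": jnp, "N": N_jax})
Black_Cox_jax = replace_globals(Black_Cox, {"np": jnp, "BL": BL_jax, "Bb": Bb_jax})

grad(Black_Cox_jax)(180.0, 10.0)  # dBlack_Cox/dV at V=180, T=10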
