
scipy.optimize.minimize : compute hessian and gradient together

The scipy.optimize.minimize function is roughly the equivalent of MATLAB's 'fminunc' function for finding local minima of functions.

In scipy, the gradient and the Hessian are passed as separate functions:

from scipy.optimize import minimize, rosen, rosen_der, rosen_hess

res = minimize(rosen, x0, method='Newton-CG',
               jac=rosen_der, hess=rosen_hess,
               options={'xtol': 1e-30, 'disp': True})

However, I have a function whose Hessian and gradient share quite a few computations, and I'd like to compute them together for efficiency. In fminunc, the objective function can be written to return multiple values, i.e.:

function [ q, grad, Hessian ] = rosen(x)

Is there a good way to pass in a function to scipy.optimize.minimize that can compute these elements together?

You could go for a caching solution, with two caveats: first, numpy arrays are not hashable, so they cannot be used directly as dictionary keys; second, how many values you need to cache depends on whether the algorithm goes back and forth a lot on x. If the algorithm only moves from one point to the next, it is enough to cache the last computed point, as follows, with f_jac and f_hes being just lambda wrappers around a single function that computes both:

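import numpy as np
from scipy.optimize import minimize

# Example: f(x, y) = x**2 + y**2, with x, y the 1st and 2nd element of x below:
def f(x):
    return x[0]**2 + x[1]**2

def f_jac_hess(x):
    # Reuse the stored (gradient, Hessian) pair if x is the last evaluated point.
    if np.array_equal(x, f_jac_hess.lastx):
        print('fetch cached value')
        return f_jac_hess.lastf
    print('new elaboration')
    res = np.array([2*x[0], 2*x[1]]), np.array([[2, 0], [0, 2]])

    # Store a copy, since the optimizer may modify x in place.
    f_jac_hess.lastx = np.array(x, copy=True)
    f_jac_hess.lastf = res

    return res

# NaN never compares equal to anything, so the first call always computes.
f_jac_hess.lastx = np.empty((2,)) * np.nan

f_jac = lambda x: f_jac_hess(x)[0]
f_hes = lambda x: f_jac_hess(x)[1]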

Now the second call fetches the cached value instead of recomputing it:

>>> f_jac([3,2])
new elaboration
array([6, 4])
>>> f_hes([3,2])
fetch cached value
array([[2, 0],
       [0, 2]])

You then call it as:

minimize(f, np.array([1.0, 2.0]), method='Newton-CG',
         jac=f_jac, hess=f_hes,
         options={'xtol': 1e-30, 'disp': True})
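
If the objective value shares work with the gradient and Hessian too, the same last-point cache can cover all three, which is closer to the MATLAB-style function that returns q, grad and Hessian. The following is only a minimal sketch of that idea, not a SciPy feature: SharedEval, fun_and_jac, n_evals and quad are hypothetical names chosen for illustration. The only SciPy behaviour it relies on is that passing jac=True tells minimize the objective returns the pair (value, gradient).

```python
import numpy as np
from scipy.optimize import minimize


class SharedEval:
    """Hypothetical helper: computes value, gradient and Hessian in one
    pass and remembers only the most recently evaluated point."""

    def __init__(self, fgh):
        self.fgh = fgh            # user function returning (value, grad, hess)
        self.last_x = None
        self.last_result = None
        self.n_evals = 0          # number of real (uncached) evaluations

    def _eval(self, x):
        x = np.asarray(x, dtype=float)
        if self.last_x is None or not np.array_equal(x, self.last_x):
            self.last_result = self.fgh(x)
            self.last_x = x.copy()  # copy: the optimizer may reuse the buffer
            self.n_evals += 1
        return self.last_result

    def fun_and_jac(self, x):
        # Shape expected by minimize(..., jac=True): (value, gradient)
        value, grad, _ = self._eval(x)
        return value, grad

    def hess(self, x):
        return self._eval(x)[2]


def quad(x):
    # Same toy objective as above: f(x, y) = x**2 + y**2
    return x[0]**2 + x[1]**2, 2.0 * x, 2.0 * np.eye(2)


shared = SharedEval(quad)
res = minimize(shared.fun_and_jac, np.array([1.0, 2.0]), method='Newton-CG',
               jac=True, hess=shared.hess)
print(res.x, shared.n_evals)
```

Whether this saves anything in practice depends on how often the solver asks for the Hessian at the point where it last evaluated the objective; comparing n_evals with res.nfev and res.nhev gives a rough idea.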
