繁体   English   中英

在Python中的给定点查找未知函数的梯度

[英]Finding gradient of an unknown function at a given point in Python

我被要求用python签名gradient(f, P0, gamma, epsilon)编写一个梯度下降的实现,其中f是一个未知且可能是多元函数,P0是梯度下降的起点,gamma是常数步骤和epsilon停止标准。

我发现棘手的是如何在不了解f情况下评估点P0f的梯度。 我知道有numpy.gradient但是在我不知道f的维的情况下,我不知道如何使用它。 另外, numpy.gradient可以与函数的样本一起使用,因此如何选择正确的样本来计算一个点上的梯度而又不对该函数和该点有任何信息?

我在这里假设, So how can i choose a generic set of samples each time I need to compute the gradient at a given point?So how can i choose a generic set of samples each time I need to compute the gradient at a given point? 表示该函数的尺寸是固定的,可以从您的起点推导。

考虑使用scipy的roximate_fprime这个演示,这是一种更容易使用包装器方法进行数值微分的方法,并且在需要但不给出jacobian的情况下,也用于scipy的优化器中。

当然,您不能忽略参数epsilon,它可能因数据而有所不同。

(此代码也忽略了optimize的args参数,这通常是一个好主意;我使用的事实是A和b在此处的范围内;肯定不是最佳实践)

import numpy as np
from scipy.optimize import approx_fprime, minimize
np.random.seed(1)

# Synthetic data
A = np.random.random(size=(1000, 20))
noiseless_x = np.random.random(size=20)
b = A.dot(noiseless_x) + np.random.random(size=1000) * 0.01

# Loss function
def fun(x):
    return np.linalg.norm(A.dot(x) - b, 2)

# Optimize without any explicit jacobian
x0 = np.zeros(len(noiseless_x))
res = minimize(fun, x0)
print(res.message)
print(res.fun)

# Get numerical-gradient function
eps = np.sqrt(np.finfo(float).eps)
my_gradient = lambda x: approx_fprime(x, fun, eps)

# Optimize with our gradient
res = res = minimize(fun, x0, jac=my_gradient)
print(res.message)
print(res.fun)

# Eval gradient at some point
print(my_gradient(np.ones(len(noiseless_x))))

输出:

Optimization terminated successfully.
0.09272331925776327
Optimization terminated successfully.
0.09272331925776327
[15.77418041 16.43476772 15.40369129 15.79804516 15.61699104 15.52977276
 15.60408688 16.29286766 16.13469887 16.29916573 15.57258797 15.75262356
 16.3483305  15.40844536 16.8921814  15.18487358 15.95994091 15.45903492
 16.2035532  16.68831635]

使用方法:

# Get numerical-gradient function with a way too big eps-value
eps = 1e-3
my_gradient = lambda x: approx_fprime(x, fun, eps)

表明eps是一个关键参数,导致:

Desired error not necessarily achieved due to precision loss.
0.09323354898565098

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM