[英]Finding gradient of an unknown function at a given point in Python
我被要求用python簽名gradient(f, P0, gamma, epsilon)
編寫一個梯度下降的實現,其中f是一個未知且可能是多元函數,P0是梯度下降的起點,gamma是常數步驟和epsilon停止標准。
我發現棘手的是如何在不了解f
情況下評估點P0
處f
的梯度。 我知道有numpy.gradient
但是在我不知道f
的維的情況下,我不知道如何使用它。 另外, numpy.gradient
可以與函數的樣本一起使用,因此如何選擇正確的樣本來計算一個點上的梯度而又不對該函數和該點有任何信息?
我在這里假設, So how can i choose a generic set of samples each time I need to compute the gradient at a given point?
, So how can i choose a generic set of samples each time I need to compute the gradient at a given point?
表示該函數的尺寸是固定的,可以從您的起點推導。
考慮使用scipy的roximate_fprime這個演示,這是一種更容易使用包裝器方法進行數值微分的方法,並且在需要但不給出jacobian的情況下,也用於scipy的優化器中。
當然,您不能忽略參數epsilon,它可能因數據而有所不同。
(此代碼也忽略了optimize的args參數,這通常是一個好主意;我使用的事實是A和b在此處的范圍內;肯定不是最佳實踐)
import numpy as np
from scipy.optimize import approx_fprime, minimize
np.random.seed(1)
# Synthetic data
A = np.random.random(size=(1000, 20))
noiseless_x = np.random.random(size=20)
b = A.dot(noiseless_x) + np.random.random(size=1000) * 0.01
# Loss function
def fun(x):
return np.linalg.norm(A.dot(x) - b, 2)
# Optimize without any explicit jacobian
x0 = np.zeros(len(noiseless_x))
res = minimize(fun, x0)
print(res.message)
print(res.fun)
# Get numerical-gradient function
eps = np.sqrt(np.finfo(float).eps)
my_gradient = lambda x: approx_fprime(x, fun, eps)
# Optimize with our gradient
res = res = minimize(fun, x0, jac=my_gradient)
print(res.message)
print(res.fun)
# Eval gradient at some point
print(my_gradient(np.ones(len(noiseless_x))))
輸出:
Optimization terminated successfully.
0.09272331925776327
Optimization terminated successfully.
0.09272331925776327
[15.77418041 16.43476772 15.40369129 15.79804516 15.61699104 15.52977276
15.60408688 16.29286766 16.13469887 16.29916573 15.57258797 15.75262356
16.3483305 15.40844536 16.8921814 15.18487358 15.95994091 15.45903492
16.2035532 16.68831635]
使用方法:
# Get numerical-gradient function with a way too big eps-value
eps = 1e-3
my_gradient = lambda x: approx_fprime(x, fun, eps)
表明eps是一個關鍵參數,導致:
Desired error not necessarily achieved due to precision loss.
0.09323354898565098
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.