[英]Univariate Linear Regression outputting NaN
我目前正在用python編寫單變量線性回歸的實現:
# implementation of univariate linear regression
import numpy as np
def cost_function(hypothesis, y, m):
return (1 / (2 * m)) * ((hypothesis - y) ** 2).sum()
def hypothesis(X, theta):
return X.dot(theta)
def gradient_descent(X, y, theta, m, alpha):
for i in range(1500):
temp1 = theta[0][0] - alpha * (1 / m) * (hypothesis(X, theta) - y).sum()
temp2 = theta[1][0] - alpha * (1 / m) * ((hypothesis(X, theta) - y) * X[:, 1]).sum()
theta[0][0] = temp1
theta[1][0] = temp2
return theta
if __name__ == '__main__':
data = np.loadtxt('data.txt', delimiter=',')
y = data[:, 1]
m = y.size
X = np.ones(shape=(m, 2))
X[:, 1] = data[:, 0]
theta = np.zeros(shape=(2, 1))
alpha = 0.01
print(gradient_descent(X, y, theta, m, alpha))
到達無窮大之后,這段代碼將輸出NaN的theta-我無法弄清楚出了什么問題,但這肯定與我在漸變下降函數中更改theta有關。
我使用的數據是我在線獲得的一個簡單的線性回歸對數據集-可以正確加載。
誰能指出我正確的方向?
您看到的問題是,當您執行X[:,1]
或data[:,1]
,您會得到形狀為(m,)的對象。 當您將形狀為(m,)的對象與形狀為(m,1)的矩陣相乘時,您將得到大小為(m,m)的矩陣
a = np.array([1,2,3])
b = np.array([[4],[5],[6]])
(a*b).shape #prints (3,3)
如果在if __name__
塊中並在gradient_descent函數中執行y = y.reshape((m,1))
X_1 = X[:,1].reshape((m,1))
應該解決問題。 現在發生的是,當你做
((hypothesis(X, theta) - y) * X[:, 1])
您會得到一個100 x 100的矩陣,這不是您想要的。
我用於測試的完整代碼是:
# implementation of univariate linear regression
import numpy as np
def cost_function(hypothesis, y, m):
return (1 / (2 * m)) * ((hypothesis - y) ** 2).sum()
def hypothesis(X, theta):
return X.dot(theta)
def gradient_descent(X, y, theta, m, alpha):
X_1 = X[:,1]
X_1 = X_1.reshape((m,1))
for i in range(1500):
temp1 = theta[0][0] - alpha * (1 / m) * (hypothesis(X, theta) - y).sum()
temp2 = theta[1][0] - alpha * (1 / m) * ((hypothesis(X, theta) - y) * X_1).sum()
theta[0][0] = temp1
theta[1][0] = temp2
return theta
if __name__ == '__main__':
data= np.random.normal(size=(100,2))
y = 30*data[:,0] + data[:, 1]
m = y.size
X = np.ones(shape=(m, 2))
y = y.reshape((m,1))
X[:, 1] = data[:, 0]
theta = np.zeros(shape=(2, 1))
alpha = 0.01
print(gradient_descent(X, y, theta, m, alpha))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.