[英]No Hidden Layer Neural Network Doesn't Equal Logistic Regression
理論上,無隱藏層神經網絡應該與邏輯回歸相同,但是,我們收集的結果差異很大。 更令人困惑的是,測試用例非常基礎,但神經網絡卻無法學習。
我們試圖選擇盡可能相似的兩個模型的參數(相同的時期數,沒有 L2 懲罰,相同的損失函數,沒有額外的優化,如動量等......)。 sklearn 邏輯回歸始終正確地找到決策邊界,變化最小。 張量流神經網絡是高度可變的,看起來偏差正在“努力”訓練。
下面包含代碼以重新創建此問題。 理想的解決方案將具有與邏輯回歸決策邊界非常相似的張量流決策邊界。
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, Dense, Flatten, Input, Concatenate, Dropout
from tensorflow.keras import Sequential, Model
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.linear_model import LogisticRegression
X = np.array([[1, 1],
[2, 2]])
y = np.array([0, 1])
model = LogisticRegression(penalty = 'none',
solver='sag',
max_iter = 300,
tol = 1e-100)
model.fit(X, y)
model.score(X, y)
model.coef_.flatten()[1]
model.intercept_
w_1 = model.coef_.flatten()[0]
w_2 = model.coef_.flatten()[1]
b = model.intercept_
n = np.linspace(0, 3, 10000, endpoint=False)
x_n = -w_1 / w_2 * n - b / w_2
plt.scatter(X[:, 0], X[:, 1], c = y)
plt.plot(n, x_n)
plt.gca().set_aspect('equal')
plt.show()
X = np.array([[1, 1],
[2, 2]])
y = np.array([0, 1])
optimizer = SGD(learning_rate=0.01,
momentum = 0.0,
nesterov = False,
name = 'SGD')
inputs = Input(shape = (2,), name='inputs')
outputs = Dense(1, activation = 'sigmoid', name = 'outputs')(inputs)
model = Model(inputs = inputs, outputs = outputs, name = 'model')
model.compile(loss = 'bce', optimizer = optimizer, metrics = ['AUC', 'accuracy'])
model.fit(X, y, epochs = 100, verbose=False)
print(model.evaluate(X, y))
weights, bias = model.layers[1].get_weights()
weights = weights.flatten()
w_1 = weights[0]
w_2 = weights[1]
b = bias
n = np.linspace(0, 3, 10000, endpoint=False)
x_n = -w_1 / w_2 * n - b / w_2
plt.scatter(X[:, 0], X[:, 1], c = y)
plt.plot(n, x_n)
plt.grid()
plt.gca().set_aspect('equal')
plt.show()
確定這是否真的是一個錯誤的一個簡單方法是讓你的感知器中的 epoch 數量達到任意大的數字(比如 5000)。 您會注意到決策邊界接近邏輯回歸模型的決策邊界。
自然的問題是為什么 LR 需要更少的迭代來實現接近最優的決策邊界。 對於強凸函數(如您的示例), SAG 的收斂速度比 SGD 快得多。 因此,SGD 需要更長的時間才能收斂到“全局良好”的解決方案(盡管收斂到局部良好的解決方案並不多)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.