
Neural network with no hidden layers and a linear activation function should approximate a linear regression?

As I understand it, a neural network with no hidden layers and a linear activation function produces the same functional form as linear regression, i.e. y = SUM(w_i * x_i) + b, where i runs over the features you have.
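That functional form can be sketched with a tiny NumPy example (the data and weights below are made up for illustration): a single-unit linear model is just a matrix-vector product plus one shared bias.

```python
import numpy as np

# Hypothetical data: 3 samples, 2 features.
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])
w = np.array([0.5, -0.25])  # one weight per feature
b = 1.0                     # a single shared bias term

# y_i = sum_j(w_j * x_ij) + b, i.e. a plain linear model
y = X @ w + b
print(y)  # [1.0, 1.5, 2.0]
```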

I tried to prove this to myself by taking the weights and bias from a fitted linear regression, loading them into a neural network, and checking whether the predictions were the same. They were not.

I would like to know whether my understanding is incorrect, my code is wrong, or both.


from sklearn.linear_model import LinearRegression
import tensorflow as tf
from tensorflow import keras
import numpy as np

linearModel = LinearRegression()
linearModel.fit(np.array(normTrainFeaturesDf), np.array(trainLabelsDf))

# Get the weights and intercept of the linear model in a form that can be passed into the neural network
linearWeights = np.array(linearModel.coef_)
intercept = np.array([linearModel.intercept_])

trialWeights = np.reshape(linearWeights, (len(linearWeights), 1))
trialWeights = trialWeights.astype('float32')
intercept = intercept.astype('float32')
newTrialWeights = [trialWeights, intercept]

# Create a neural network and set the weights of the model to the linear model
nnModel = keras.Sequential([keras.layers.Dense(1, activation='linear', input_shape=[len(normTrainFeaturesDf.keys())]),])
nnModel.set_weights(newTrialWeights)

# Print predictions of both models (the results are vastly different)
print(linearModel.predict(np.array(normTestFeaturesDf)))
print(nnModel.predict(normTestFeaturesDf).flatten())

Yes, a neural network with a single layer and no activation function is equivalent to linear regression.

Defining some variables you did not include (as DataFrames, since your code calls normTrainFeaturesDf.keys(), which plain NumPy arrays do not support):

import pandas as pd

normTrainFeaturesDf = pd.DataFrame(np.random.rand(100, 10))
normTestFeaturesDf = pd.DataFrame(np.random.rand(10, 10))
trainLabelsDf = pd.Series(np.random.rand(100))

Then the output is as expected:

>>> linear_model_preds = linearModel.predict(np.array(normTestFeaturesDf))
>>> nn_model_preds = nnModel.predict(normTestFeaturesDf).flatten()

>>> print(linear_model_preds)
>>> print(nn_model_preds)
[0.46030349 0.69676376 0.43064266 0.4583325  0.50750268 0.51753189
 0.47254946 0.50654825 0.52998559 0.35908762]
[0.46030346 0.69676375 0.43064266 0.45833248 0.5075026  0.5175319
 0.47254944 0.50654817 0.52998555 0.3590876 ]

The numbers are identical apart from tiny differences due to floating-point precision.

>>> np.allclose(linear_model_preds, nn_model_preds)
True
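Another way to see why the two models must agree: sklearn's predict is exactly the affine map X @ coef_ + intercept_, which is the same computation a single linear Dense unit performs. A self-contained sketch with made-up random data (variable names here are illustrative, not from the question):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X_train = rng.random((100, 10))
y_train = rng.random(100)
X_test = rng.random((10, 10))

lm = LinearRegression().fit(X_train, y_train)

# Compute predictions by hand: the same X @ W + b that a
# single-unit Dense layer with a linear activation computes.
manual = X_test @ lm.coef_ + lm.intercept_
print(np.allclose(manual, lm.predict(X_test)))  # True
```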
