
PyTorch: Compute Hessian matrix of the model

Say, for some reason, I want to fit a linear regression using PyTorch, as illustrated below.

How do I compute the Hessian matrix of the model, so that I can ultimately compute standard errors for the parameter estimates?

import torch 
import torch.nn as nn
# set seed 
torch.manual_seed(42)
# define the model
class OLS_pytorch(nn.Module):
    def __init__(self, X, Y):
        super(OLS_pytorch, self).__init__()
        self.X = X
        self.Y = Y
        # nn.Parameter already sets requires_grad=True
        self.beta = nn.Parameter(torch.ones(X.shape[1], 1))
        self.intercept = nn.Parameter(torch.ones(1))
        self.loss = nn.MSELoss()
        
    def forward(self):
        return self.X @ self.beta + self.intercept
    
    def fit(self, lr=0.01, epochs=1000):
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss = self.loss(self.forward(), self.Y)
            loss.backward()
            optimizer.step()
            if epoch % 10 == 0:
                print(f"Epoch {epoch} loss: {loss.item()}")
        return self

Generate some data and use the model:

# Generate some data    
X = torch.randn(100, 1)
Y = 2 * X + 3 + torch.randn(100, 1)
# fit the model
model = OLS_pytorch(X, Y)
model.fit()
# extract parameters
model.beta, model.intercept

#Epoch 980 loss: 0.7803605794906616
#Epoch 990 loss: 0.7803605794906616
#(Parameter containing:
# tensor([[2.0118]], requires_grad=True),
# Parameter containing:
# tensor([3.0357], requires_grad=True))
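As a quick sanity check (not part of the original post), the same estimates can be recovered in closed form from the normal equations; the Adam fit above should converge to these values. A minimal sketch, assuming the same seed-42 data (the names `X_aug` and `theta` are illustrative):

```python
import torch

torch.manual_seed(42)
X = torch.randn(100, 1)
Y = 2 * X + 3 + torch.randn(100, 1)

# augment X with a column of ones so the intercept is estimated jointly
X_aug = torch.cat([X, torch.ones(100, 1)], dim=1)

# closed-form OLS: theta = (X'X)^{-1} X'Y
theta = torch.linalg.solve(X_aug.T @ X_aug, X_aug.T @ Y)
print(theta)  # slope first, then intercept; close to (2.01, 3.04) for this seed
```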

For example, in R, using the same data and the lm() function, I recover the same parameters, but I can also recover the Hessian matrix and compute the standard errors.

ols <- lm(Y ~ X, data = xy)
ols$coefficients
#(Intercept)           X 
#   3.035674    2.011811 
vcov(ols)
#              (Intercept)             X
# (Intercept)  0.0079923921 -0.0004940884
# X           -0.0004940884  0.0082671053

summary(ols)
# Coefficients:
#             Estimate Std. Error t value Pr(>|t|)    
# (Intercept)  3.03567    0.08940   33.96   <2e-16 ***
# X            2.01181    0.09092   22.13   <2e-16 ***
# ---
# Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
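For reference, the vcov matrix and standard errors that R reports can also be computed directly in PyTorch from the classical formula sigma^2 * (X'X)^{-1}. A minimal sketch, assuming the same seed-42 data as above (column order chosen to match R's coefficient table):

```python
import torch

torch.manual_seed(42)
X = torch.randn(100, 1)
Y = 2 * X + 3 + torch.randn(100, 1)
N = X.shape[0]

# intercept column first, matching the order of R's coefficient table
X_aug = torch.cat([torch.ones(N, 1), X], dim=1)

# OLS estimate and residual variance with N - 2 degrees of freedom
theta = torch.linalg.solve(X_aug.T @ X_aug, X_aug.T @ Y)
resid = Y - X_aug @ theta
sigma2 = (resid ** 2).sum() / (N - 2)

# classical variance-covariance matrix: sigma^2 * (X'X)^{-1}
vcov = sigma2 * torch.inverse(X_aug.T @ X_aug)
se = torch.sqrt(torch.diag(vcov))
print(se)  # roughly (0.089, 0.091) for this seed: intercept SE, then slope SE
```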

UPDATE: using @cherrywoods's answer

Here is how to match the standard errors produced by lm() in R:

# predict
N = model.X.shape[0]
y_pred = model.X @ model.beta + model.intercept
sigma_hat = torch.sum((y_pred - model.Y)**2) / (N - 2)  # 2 is the number of estimated parameters

from torch.autograd.functional import hessian

def loss(beta, intercept):
    y_pred = model.X @ beta + intercept
    return model.loss(y_pred, model.Y)

# hessian() returns a nested tuple of blocks; assemble them into a 2x2 matrix
H_blocks = hessian(loss, (model.beta, model.intercept))
H = torch.tensor([[H_blocks[0][0].item(), H_blocks[0][1].item()],
                  [H_blocks[1][0].item(), H_blocks[1][1].item()]])

# MSELoss averages over N, so X'X = N * H / 2 and vcov = sigma_hat * (X'X)^{-1}
se = torch.sqrt(torch.diag(sigma_hat * torch.inverse(N * H / 2)))
print(se)
# ≈ tensor([0.0909, 0.0894]) — beta SE, then intercept SE, matching summary(ols)

You can compute the Hessian using torch.autograd.functional.hessian.

from torch.autograd.functional import hessian

def loss(beta, intercept):
    y_pred = model.X @ beta + intercept
    return model.loss(y_pred, model.Y)

H = hessian(loss, (model.beta, model.intercept))
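Note that with a tuple of inputs, `hessian` returns a nested tuple of blocks (one per input pair, in the tuple's order) rather than a single matrix. A self-contained sketch that rebuilds the question's data instead of reusing `model`, assembles the blocks, and checks them against the analytic Hessian of the mean-squared error, (2/N) * X'X on the intercept-augmented design matrix:

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(42)
X = torch.randn(100, 1)
Y = 2 * X + 3 + torch.randn(100, 1)
N = X.shape[0]

mse = torch.nn.MSELoss()

def loss(beta, intercept):
    return mse(X @ beta + intercept, Y)

# the loss is quadratic, so its Hessian is the same at any evaluation point
beta0 = torch.zeros(1, 1)
intercept0 = torch.zeros(1)

blocks = hessian(loss, (beta0, intercept0))
# flatten the 2x2 block structure into a dense matrix
H = torch.tensor([[blocks[0][0].item(), blocks[0][1].item()],
                  [blocks[1][0].item(), blocks[1][1].item()]])

# analytic Hessian of the (mean-reduced) MSE: (2/N) * X_aug' X_aug
X_aug = torch.cat([X, torch.ones(N, 1)], dim=1)  # beta column first, as in the tuple
H_analytic = 2.0 / N * (X_aug.T @ X_aug)
print(torch.allclose(H, H_analytic, atol=1e-4))  # True
```

The agreement confirms that dividing the autograd Hessian by 2 (and multiplying by N when MSELoss uses its default mean reduction) recovers X'X for the standard-error formula.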
