Error: mat1 and mat2 shapes cannot be multiplied (1000x10 and 1x1)
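I'm trying to implement ridge regression in PyTorch, define the loss function, and plot the loss over the training iterations. The only problem is that I keep getting the error: mat1 and mat2 shapes cannot be multiplied (1000x10 and 1x1). I'd like to convert the second matrix to 1x10 so the code runs, but I can't seem to get that to work.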
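The shapes in the error message can be reproduced in isolation. The sketch below is a minimal reproduction, assuming a random tensor `x` as a stand-in for `X_t`. `nn.Linear(in_features, out_features)` stores its weight as an `(out_features, in_features)` matrix and multiplies the input by its transpose, so the `nn.Linear(1, 1)` layer inside the `MyLinear` model in the code that follows can only accept inputs with one column, while `X_t` has 10.

```python
import torch
import torch.nn as nn

# Random stand-in for X_t (assumption): 1000 samples, 10 features.
x = torch.randn(1000, 10)

# nn.Linear(in_features, out_features) keeps weight with shape (out_features, in_features)
# and computes x @ weight.T + bias, so x.shape[1] must equal in_features.
bad = nn.Linear(1, 1)
print(bad.weight.shape)  # torch.Size([1, 1])

error_message = None
try:
    bad(x)  # (1000, 10) @ (1, 1): inner dimensions 10 and 1 do not match
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# One input per feature: weight is (1, 10), so weight.T is (10, 1).
good = nn.Linear(10, 1)
out = good(x)
print(out.shape)  # torch.Size([1000, 1])
```

In other words, the matrix that has to change is the layer's weight, which is set by `in_features`, not the data.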
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
n = 1000
p = 10
mean = np.zeros((p))
val = 0.8
cov = np.ones((p,p))*val
cov = cov + np.eye(p)*(1-val)
np.random.seed(10)
X = np.random.multivariate_normal(mean, cov, n)
theta_true = np.concatenate((np.ones((5,1)), np.zeros((5,1))),axis=0)
delta=0.5
Sigma = np.eye(n,n,k=-1)*0.4 + np.eye(n,n)*1 + np.eye(n,n,k=1)*0.4
mean = np.zeros(n)
e = np.random.multivariate_normal(mean, Sigma, 1)
y=X@theta_true + delta*e.T
import torch
X_t = torch.from_numpy(X).float()
y_t = torch.from_numpy(y).float()
Sigma_t = torch.from_numpy(Sigma).float()
import torch.nn as nn
import torch.nn.functional as F
class MyLinear(nn.Module):
    def __init__(self):
        super(MyLinear, self).__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        out = self.linear(x)
        return out
def L2_norm(model):
    return torch.sum(list(model.parameters())[0]**2)
def L1_norm(model):
    return torch.sum(torch.abs(list(model.parameters())[0]))
def ridge_loss(y_pred, y_true, model, lambda_):
    mse = F.mse_loss(y_pred, y_true)
    regularization = lambda_ * L2_norm(model)
    return mse + regularization
import matplotlib.pyplot as plt
model = MyLinear()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lambda_ = 0.1
num_epochs = 1000
loss_values = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    y_pred = model(X_t)
    loss = ridge_loss(y_pred, y_t, model, lambda_)
    loss_values.append(loss.item())
    loss.backward()
    optimizer.step()
plt.plot(loss_values)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Ridge Regression Loss over Iterations')
plt.show()
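A likely fix, sketched under stated assumptions: give the layer one input per feature, i.e. `nn.Linear(p, 1)` with `p = 10`, so its weight has shape `(1, 10)` and `X_t @ weight.T` is `(1000, 1)`, the same shape as `y_t`. To stay self-contained, the sketch replaces the question's correlated NumPy data with synthetic `torch.randn` data (an assumption) and omits the plot. The `in_features` constructor argument is a hypothetical addition to `MyLinear`, and the penalty uses `model.linear.weight`, which is the same tensor as `list(model.parameters())[0]` in the question's `L2_norm`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

n, p = 1000, 10
# Synthetic stand-in for the question's data (assumption): first 5 coefficients 1, rest 0.
theta_true = torch.cat([torch.ones(5, 1), torch.zeros(5, 1)], dim=0)  # (10, 1)
X_t = torch.randn(n, p)                                               # (1000, 10)
y_t = X_t @ theta_true + 0.5 * torch.randn(n, 1)                      # (1000, 1)


class MyLinear(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        # in_features must equal the number of columns of X_t (10 here), not 1.
        self.linear = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.linear(x)


def ridge_loss(y_pred, y_true, model, lambda_):
    mse = F.mse_loss(y_pred, y_true)
    # Penalize only the (1, p) weight matrix; the bias is left unregularized.
    return mse + lambda_ * torch.sum(model.linear.weight ** 2)


model = MyLinear(p)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lambda_ = 0.1
loss_values = []

for epoch in range(1000):
    optimizer.zero_grad()
    y_pred = model(X_t)  # (1000, 10) @ (10, 1) -> (1000, 1), same shape as y_t
    loss = ridge_loss(y_pred, y_t, model, lambda_)
    loss_values.append(loss.item())
    loss.backward()
    optimizer.step()

print(loss_values[0], loss_values[-1])
```

With the shapes aligned, `loss_values` can be passed to `plt.plot` exactly as in the question.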
I tried changing the theta_true definition to reshape the matrix, but the same error occurred:
theta_true = np.concatenate((np.ones((5,1)), np.zeros((5,1)))).reshape(10, 1)
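Reshaping `theta_true` is not expected to change anything: `theta_true` is only used to generate `y` and is never passed to the model, and `np.concatenate` already joins along axis 0 by default, so both definitions give the same `(10, 1)` array and `.reshape(10, 1)` is a no-op. A quick shape check, assuming a plain normal `X` with the question's `n` and `p` in place of the correlated one:

```python
import numpy as np

n, p = 1000, 10
X = np.random.default_rng(10).normal(size=(n, p))  # stand-in for the question's X

theta_a = np.concatenate((np.ones((5, 1)), np.zeros((5, 1))), axis=0)
theta_b = np.concatenate((np.ones((5, 1)), np.zeros((5, 1)))).reshape(10, 1)

print(theta_a.shape, theta_b.shape)      # (10, 1) (10, 1)
print(np.array_equal(theta_a, theta_b))  # True
print((X @ theta_a).shape)               # (1000, 1)
```

The data side is already consistent; the `1x1` in the error is the transposed weight of `nn.Linear(1, 1)`.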