繁体   English   中英

PyTorch - AssertionError:张量之间的大小不匹配

[英]PyTorch - AssertionError: Size mismatch between tensors

我正在尝试调整为线性回归创建的 Pytorch 脚本。 它最初是为了将一组随机值(使用 np.random 创建)作为特征和目标而编写的。

我现在已经创建了一个 dataframe 的实际数据进行分析:

df = pd.read_csv('file_name.csv')

df 看起来像这样:

      X1     X2     X3     X4    X5   X6   X7  X8    Y1     Y2
0    0.98  514.5  294.0  110.25  7.0   2  0.0   0  15.55  21.33
1    0.98  514.5  294.0  110.25  7.0   3  0.0   0  15.55  21.33
2    0.98  514.5  294.0  110.25  7.0   4  0.0   0  15.55  21.33
3    0.98  514.5  294.0  110.25  7.0   5  0.0   0  15.55  21.33
4    0.90  563.5  318.5  122.50  7.0   2  0.0   0  20.84  28.28

...我目前只提取两列(X1 和 X2)作为我的特征,一列(Y1)作为我的目标,如下所示:

x = df[['X1', 'X2']]
y = df['Y1']

所以功能看起来像这样:

      X1     X2
0    0.98  514.5
1    0.98  514.5
2    0.98  514.5
3    0.98  514.5
4    0.90  563.5

和目标看起来像这样:

        Y1
0      15.55
1      15.55
2      15.55
3      15.55
4      20.84

但是,当我尝试将特征(X1 和 X1)和目标(Y1)转换为张量以便将它们提供给 NN 时,代码在以下行失败:

数据集 = TensorDataset(x_tensor_flat, y_tensor_flat)

我得到错误:

line 45, in <module> dataset = TensorDataset(x_tensor, y_tensor)
AssertionError: Size mismatch between tensors

显然有一些塑造问题在起作用,但我不知道是什么。 我试图展平和转置张量,但我得到了同样的错误。 任何帮助将不胜感激。

这是导致问题的完整代码部分:

import pandas as pd
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

device = 'cuda' if torch.cuda.is_available() else 'cpu'


df = pd.read_csv('file_name.csv')
x = df[['X1', 'X2']]
y = df['Y1']


x_tensor = torch.from_numpy(np.array(x)).float()
y_tensor = torch.from_numpy(np.array(y)).float()


train_loader = DataLoader(dataset=train_dataset, batch_size=10)
val_loader = DataLoader(dataset=val_dataset, batch_size=10)


class ManualLinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

def make_train_step(model, loss_fn, optimizer):
    def train_step(x, y):
        model.train()
        yhat = model(x)
        loss = loss_fn(y, yhat)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item()
    return train_step


torch.manual_seed(42)

model = ManualLinearRegression().to(device) 
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.SGD(model.parameters(), lr=1e-1)
train_step = make_train_step(model, loss_fn, optimizer)

n_epochs = 50
training_losses = []
validation_losses = []
print(model.state_dict())

for epoch in range(n_epochs):
    batch_losses = []
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        loss = train_step(x_batch, y_batch)
        batch_losses.append(loss)
    training_loss = np.mean(batch_losses)
    training_losses.append(training_loss)

    with torch.no_grad():
        val_losses = []
        for x_val, y_val in val_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            model.eval()
            yhat = model(x_val)
            val_loss = loss_fn(y_val, yhat).item()
            val_losses.append(val_loss)
        validation_loss = np.mean(val_losses)
        validation_losses.append(validation_loss)

    print(f"[{epoch+1}] Training loss: {training_loss:.3f}\t Validation loss: {validation_loss:.3f}")

print(model.state_dict())

问题在于您如何调用random_split function。 请注意,它将长度作为输入,而不是拆分的百分比或比率。 错误大致相同,即您指定的长度总和(80+20)与数据长度(5)不同。

下面的代码片段应该可以解决您的问题。 此外,您不需要展平张量......我认为。

dataset = TensorDataset(x_tensor, y_tensor)
val_size = int(len(dataset)*0.2)
train_size = len(dataset)- int(len(dataset)*0.2)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

这样做的问题是,没有批量大小来指定所有尺寸都是不同的,所以要解决这个问题

dataset = CustomDataset(x_tensor_flat, y_tensor_flat) # Use this should work equally well

如果您仍想使用 TensorDataset

dataset = TensorDataset(x_tensor_flat.unsqueeze(0), y_tensor_flat.unsqueeze(0)) # Make sure they have the same batch dimensions (e.g (1, 100) , (1, 20) # can be different as long as batch matches)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM