简体   繁体   English

pytorch恢复模型具有不同的批量大小

[英]pytorch restore model with different batch size

I have a problem about how to reload pytorch model with different batch size. 我有一个关于如何重新加载不同批量大小的pytorch模型的问题。 In training, my batch size is 64, but in inference, I would like the batch size is 1(feed data one by one). 在培训中,我的批量大小是64,但在推断中,我希望批量大小为1(逐个提供数据)。 This is the code I used to save and restore model: 这是我用来保存和恢复模型的代码:

torch.save(agent.qnetwork_local.state_dict(), './ckpt/checkpoint.pth')
saved_model = QNetwork(state_size=37, action_size=4, seed=0)
saved_model.load_state_dict(torch.load('./ckpt/checkpoint.pth'))

And I got this error when running the inference model: 运行推理模型时出现此错误:

RuntimeError: size mismatch, m1: [37 x 1], m2: [37 x 64] at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2070

This error means the model's input must be 37x64, where 37 is the data dimension and 64 is the training batch size. 此错误表示模型的输入必须为37x64,其中37是数据维度,64是训练批量大小。 But testing input is 37x1 which means data dimension is 37 and batch size is 1. 但测试输入为37x1,这意味着数据维度为37,批量大小为1。

Is there any solution to different batch size in reload pytorch model? 在重装pytorch模型中有不同批量大小的解决方案吗? Thank you very much. 非常感谢你。

I eventually managed to do it using batch_size=1 in DataLoader 我最终设法在DataLoader中使用batch_size=1来完成它

import torch
import pandas as pd
from torch.utils.data.dataloader import DataLoader

df = pd.read_csv('data.csv')
df = df.values

# Use CustomDataset class for your data
inference_dataset = CustomDataset(x=df[:1, 0:2])

inference_dataloader = DataLoader(inference_dataset, batch_size=1, shuffle=False, num_workers=4)

# 
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load('./model/model'))
model.eval()

for i, x in enumerate(inference_dataloader):
    x = x.float()
    y_pred = model(x)
    print(y_pred.value)

When you build up your model , you can use -1 to dynamically represent your batch size. 构建模型时,可以使用-1动态表示批量大小。 For example , the below is the forward stage code 例如,以下是前台代码

def forward(self, x):
     x = self.conv1(x)
     x = self.layer1(x)
     x = self.layer2(x)
     x = self.avgpool(x)
     x = x.view(-1, 37)
 #instead using x.view(64,37) 
     x = self.fc(x)

hope it can help you 希望它可以帮到你

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM