
PyTorch: restore a model with a different batch size

I have a problem reloading a PyTorch model with a different batch size. During training my batch size is 64, but at inference time I would like a batch size of 1 (feeding the data one sample at a time). This is the code I use to save and restore the model:

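import torch
# Save only the learned parameters (the state_dict) of the trained network
torch.save(agent.qnetwork_local.state_dict(), './ckpt/checkpoint.pth')
# Rebuild the same architecture and load the saved parameters into it
saved_model = QNetwork(state_size=37, action_size=4, seed=0)
saved_model.load_state_dict(torch.load('./ckpt/checkpoint.pth'))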
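For reference, here is a minimal sketch of the same save/load round trip. The real `QNetwork` definition is not shown in the question, so the class below is a hypothetical stand-in, assumed to have two hidden layers of 64 units, and it writes to a temporary file rather than `./ckpt/checkpoint.pth`. It illustrates that a `state_dict` stores only parameter tensors, whose shapes come from the layer sizes and never from the batch size, so the reloaded model accepts a batch of one as long as the input is shaped `(batch, 37)`.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    """Hypothetical Q-network: state_size -> 64 -> 64 -> action_size."""

    def __init__(self, state_size, action_size, seed, fc1_units=64, fc2_units=64):
        super().__init__()
        torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Stand-in for agent.qnetwork_local after training (training is omitted).
trained = QNetwork(state_size=37, action_size=4, seed=0)
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(trained.state_dict(), ckpt_path)

# Parameter shapes depend only on layer sizes; none has a batch dimension.
for name, tensor in trained.state_dict().items():
    print(name, tuple(tensor.shape))

# Rebuild the architecture, load the weights, and run one state through it.
saved_model = QNetwork(state_size=37, action_size=4, seed=0)
saved_model.load_state_dict(torch.load(ckpt_path))
saved_model.eval()

state = torch.randn(1, 37)  # a batch of one: shape (batch, features)
with torch.no_grad():
    q_values = saved_model(state)
print(q_values.shape)  # torch.Size([1, 4])
```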

And I get this error when running inference:

RuntimeError: size mismatch, m1: [37 x 1], m2: [37 x 64] at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2070

As far as I can tell, this error means the model's input must be 37x64, where 37 is the data dimension and 64 is the training batch size, whereas the test input is 37x1, i.e. data dimension 37 and batch size 1.
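The shapes in that error message can also be read another way. In older PyTorch versions, `nn.Linear` computed `input @ weight.t()` and reported the input as `m1` and the transposed weight as `m2`. Since `nn.Linear(in_features, out_features)` stores a weight of shape `(out_features, in_features)`, a `[37 x 64]` transposed weight most likely belongs to a first layer like `nn.Linear(37, 64)` (64 hidden units), and a `[37 x 1]` input suggests the single state was passed as a column of shape `(37, 1)` instead of a row of shape `(1, 37)`. The sketch below assumes such a first layer, reproduces the mismatch, and shows that adding a leading batch dimension with `unsqueeze(0)` avoids it. Newer PyTorch versions word the error differently ("mat1 and mat2 shapes cannot be multiplied").

```python
import torch
import torch.nn as nn

# Assumed first layer of the Q-network: 37 state features -> 64 hidden units.
fc1 = nn.Linear(37, 64)
print(fc1.weight.shape)  # torch.Size([64, 37]); the layer computes x @ weight.t()

state = torch.randn(37)  # one state vector with no batch dimension

# Wrong: a column of shape (37, 1) puts the 37 features on the batch axis.
column = state.view(37, 1)
try:
    fc1(column)
except RuntimeError as err:
    print("shape error:", err)

# Right: a row of shape (1, 37) is a batch that contains one sample.
row = state.unsqueeze(0)
out = fc1(row)
print(out.shape)  # torch.Size([1, 64])
```

If the state comes from the environment as a NumPy array (an assumption here), something like `torch.from_numpy(state).float().unsqueeze(0)` produces the `(1, 37)` batch before calling the model.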

Is there a way to use a different batch size when reloading a PyTorch model? Thank you very much.

I eventually managed to do it by using batch_size=1 in the DataLoader:

import torch
import pandas as pd
from torch.utils.data import DataLoader

df = pd.read_csv('data.csv')
df = df.values

# Use CustomDataset class for your data (your own torch Dataset subclass)
inference_dataset = CustomDataset(x=df[:1, 0:2])

inference_dataloader = DataLoader(inference_dataset, batch_size=1, shuffle=False, num_workers=4)

# Rebuild the model and load the trained weights
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load('./model/model'))
model.eval()  # switch dropout/batch-norm layers to inference behaviour

with torch.no_grad():  # gradients are not needed for inference
    for i, x in enumerate(inference_dataloader):
        x = x.float()
        y_pred = model(x)
        print(y_pred)  # or y_pred.item() for a single-element output
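The same pattern can be tried without a CSV file or a custom class. This is a minimal sketch that swaps in `torch.utils.data.TensorDataset` for `CustomDataset`, random numbers for `data.csv`, and a small hypothetical `nn.Sequential` model for `TheModelClass`. It only demonstrates that each batch from a `DataLoader` with `batch_size=1` has shape `(1, features)`, which a model trained on larger batches accepts unchanged.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins: 5 samples with 2 features each, and a tiny model.
features = torch.randn(5, 2)
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))
model.eval()

# Training could have used batch_size=64; inference uses one sample per batch.
inference_loader = DataLoader(TensorDataset(features), batch_size=1, shuffle=False)

predictions = []
with torch.no_grad():
    for (x,) in inference_loader:  # each batch is a 1-element list holding the tensor
        assert x.shape == (1, 2)  # (batch, features)
        y_pred = model(x)
        predictions.append(y_pred.item())  # .item() turns a 1-element tensor into a float

print(len(predictions))  # 5
```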

When you build your model, you can pass -1 to view() so that the batch size is inferred dynamically. For example, here is the forward pass:

def forward(self, x):
    x = self.conv1(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.avgpool(x)
    x = x.view(-1, 37)  # instead of x.view(64, 37)
    x = self.fc(x)
    return x
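To see why `-1` matters, here is a minimal sketch with a hypothetical module whose convolution output flattens to exactly 37 values per sample (a `Conv1d` with kernel size 4 maps length 40 to length 37; these sizes are made up for illustration). `view(-1, 37)` lets PyTorch infer the batch dimension, so the same weights serve batches of 64 and of 1, while a hard-coded `view(64, 37)` fails for any other batch size. `torch.flatten(x, 1)` or `nn.Flatten()` would keep the batch dimension in the same way without repeating the feature count.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model: 1-channel signal of length 40 -> 37 features -> 4 outputs."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=4)  # 40 -> 37
        self.fc = nn.Linear(37, 4)

    def forward(self, x):
        x = self.conv1(x)   # (batch, 1, 37)
        x = x.view(-1, 37)  # (batch, 37); -1 infers the batch size
        return self.fc(x)   # (batch, 4)


model = TinyNet().eval()
with torch.no_grad():
    print(model(torch.randn(64, 1, 40)).shape)  # torch.Size([64, 4])
    print(model(torch.randn(1, 1, 40)).shape)   # torch.Size([1, 4])

# A hard-coded batch size breaks as soon as the batch changes.
try:
    torch.randn(1, 1, 37).view(64, 37)
except RuntimeError as err:
    print("view(64, 37) failed:", err)
```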

Hope it helps.
