
Saving the weights of a PyTorch .pth model into a .txt or .json

I am trying to save the weights of a PyTorch model into a .txt or .json file. When I write them to a .txt file like this,

import torch
model = torch.load("model_path")
string = str(model)
with open('some_file.txt', 'w') as fp:
    fp.write(string)

I get a file where not all the weights are saved, i.e. there are ellipses throughout the text file. I cannot write it to JSON either, since the model contains tensors, which are not JSON serializable (unless there is a way that I do not know of). How can I save the weights in the .pth file to some format such that no information is lost and the values can be easily inspected?

Thanks

When you call str(model.state_dict()), it recursively calls str on the elements it contains, so the problem is how the string representation of each individual tensor is built. Large tensors are summarized with ellipses by default. You should raise the limit on how many elements are printed in each tensor's string representation:

torch.set_printoptions(profile="full")

See the difference with this:

import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
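Applied to the original question, the same setting can be switched on just before the text file is written. The following is a minimal sketch, not a definitive recipe: the `nn.Linear` model and the temporary output path are hypothetical stand-ins for the asker's loaded model and `some_file.txt`.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the loaded model; its weight has 4000 elements,
# which is more than the default print threshold, so it gets summarized.
model = nn.Linear(2000, 2)

# With the default profile, the string form contains "..." placeholders.
torch.set_printoptions(profile="default")
truncated = str(model.state_dict())

# With the "full" profile, every element is written out.
torch.set_printoptions(profile="full")
full = str(model.state_dict())
torch.set_printoptions(profile="default")  # restore the default afterwards

# Hypothetical output location; any writable path would do.
out_path = os.path.join(tempfile.mkdtemp(), "weights.txt")
with open(out_path, "w") as fp:
    fp.write(full)
```

Note that the printed values are still rounded to the configured print precision (4 digits by default), so a text dump like this is better suited to inspection than to lossless storage.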

Tensors are currently not JSON serializable.
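As a small illustration of that limitation (a sketch, assuming nothing beyond the standard `json` module and `torch`): passing a tensor directly to `json.dumps` raises a `TypeError`, while converting it to a plain Python list first works.

```python
import json

import torch

weights = torch.tensor([1.0, 2.0])

# Passing the tensor directly fails, because json does not know the type.
try:
    json.dumps(weights)
    error_raised = False
except TypeError:
    error_raised = True

# Converting to a nested Python list first makes it serializable.
encoded = json.dumps(weights.tolist())
print(error_raised, encoded)  # True [1.0, 2.0]
```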

A bit late but hopefully this helps. This is how you store it:

import json
from json import JSONEncoder

import torch

class EncodeTensor(JSONEncoder):
    def default(self, obj):
        # Convert tensors to nested Python lists, which json can serialize
        if isinstance(obj, torch.Tensor):
            return obj.cpu().detach().numpy().tolist()
        # Defer to the base class for anything else (it raises TypeError)
        return super().default(obj)

# `model` is the already loaded torch.nn.Module
with open('torch_weights.json', 'w') as json_file:
    json.dump(model.state_dict(), json_file, cls=EncodeTensor)

Keep in mind that the stored values are plain Python lists, so you have to convert them back with torch.Tensor(list) before you use the weights (or with torch.tensor(list, dtype=...) when the original dtype matters).
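The way back from the JSON file to a usable model could look roughly like this. It is a sketch under stated assumptions: the `nn.Linear` model and the temporary file path are hypothetical stand-ins, the weights are written with a plain dict comprehension rather than the encoder class so that the example is self-contained, and the dtype of each restored tensor is taken from a freshly built model of the same architecture.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained model.
model = nn.Linear(3, 2)

# Write the weights as nested lists (same result as a tensor-aware encoder).
json_path = os.path.join(tempfile.mkdtemp(), "torch_weights.json")
with open(json_path, "w") as json_file:
    json.dump({k: v.detach().cpu().tolist() for k, v in model.state_dict().items()},
              json_file)

# Read the lists back and rebuild tensors with the dtype the model expects.
with open(json_path) as json_file:
    raw = json.load(json_file)

new_model = nn.Linear(3, 2)  # same architecture, freshly initialized
reference = new_model.state_dict()
restored = {name: torch.tensor(values, dtype=reference[name].dtype)
            for name, values in raw.items()}

new_model.load_state_dict(restored)
```

Taking the dtype from the target model avoids silently turning integer buffers (for example batch-norm counters) into floats, which is what a blanket torch.Tensor(list) would do.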
