How to save weights in a .npy file in the same format that the CNN uses?

I am using a GitHub repository containing a trained CNN whose weight parameters are provided in a .npy file. The model loads the weights and uses the model parameters like this:

import torch

# CNN_Model and batch_size come from the repository (not shown here)
model = CNN_Model(batch_size)
weight_file = "weight_file.npy"
dtype = torch.FloatTensor
model.load_state_dict(load_weights(model, weight_file, dtype))

And load_weights is defined as:

import numpy as np
import torch

def load_weights(model, filename, dtype):
    model_params = model.state_dict()
    # allow_pickle=True is needed on NumPy >= 1.16.3 to load a pickled dict
    data_dict = np.load(filename, encoding='latin1', allow_pickle=True).item()
    # The permute implies kernels are stored as (H, W, in, out); Conv2d expects (out, in, H, W)
    model_params["conv1.weight"] = torch.from_numpy(data_dict["conv1"]["weights"]).type(dtype).permute(3, 2, 0, 1)
    model_params["conv1.bias"] = torch.from_numpy(data_dict["conv1"]["biases"]).type(dtype)
    model_params["bn1.weight"] = torch.from_numpy(data_dict["bn_conv1"]["scale"]).type(dtype)
    model_params["bn1.bias"] = torch.from_numpy(data_dict["bn_conv1"]["offset"]).type(dtype)
    return model_params
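The permute(3, 2, 0, 1) call implies that the .npy file stores convolution kernels as (height, width, in_channels, out_channels), while PyTorch's Conv2d weight is (out_channels, in_channels, height, width). Saving back to the original format therefore needs the inverse permutation. A minimal sketch with a made-up kernel shape, showing that permute(2, 3, 1, 0) undoes the conversion:

```python
import numpy as np
import torch

# Hypothetical kernel in the layout the .npy file appears to use: (H, W, in, out)
hwio = np.random.rand(3, 3, 4, 8).astype(np.float32)

# What load_weights does: (H, W, in, out) -> (out, in, H, W)
oihw = torch.from_numpy(hwio).permute(3, 2, 0, 1)

# Inverse permutation: (out, in, H, W) -> (H, W, in, out)
back = oihw.permute(2, 3, 1, 0).contiguous().numpy()

print(tuple(oihw.shape))           # (8, 4, 3, 3)
print(back.shape)                  # (3, 3, 4, 8)
print(np.array_equal(back, hwio))  # True
```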

I have added a training module to it and am trying to fine-tune the weights on my own dataset. After training, I want to save the new weights to a .npy file using the same data_dict keys as the previously loaded weight file, so that I can use them again with the CNN model.

How should I index the weights with the same names before saving the data_dict array using:

np.save("trained_weight_file.npy", data_dict)
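One way to get the same keys is to write the inverse of load_weights: take the tensors from model.state_dict(), undo the permute, and store them back under the original nested keys. This is a minimal sketch that assumes only the four entries shown in load_weights matter; it starts from a copy of the original data_dict so that any other entries in the file (not visible in the question) are carried over unchanged. The function name save_weights and its arguments are hypothetical.

```python
import numpy as np
import torch


def save_weights(model, original_filename, new_filename):
    """Write model weights back in the nested data_dict format read by load_weights."""
    # Start from the original dict so entries that load_weights never touches are kept.
    data_dict = np.load(original_filename, encoding='latin1', allow_pickle=True).item()
    params = model.state_dict()

    def to_numpy(tensor):
        # Detach from autograd and move to CPU before converting to a NumPy array.
        return tensor.detach().cpu().numpy()

    # Undo permute(3, 2, 0, 1): (out, in, H, W) -> (H, W, in, out)
    data_dict["conv1"]["weights"] = to_numpy(params["conv1.weight"].permute(2, 3, 1, 0).contiguous())
    data_dict["conv1"]["biases"] = to_numpy(params["conv1.bias"])
    data_dict["bn_conv1"]["scale"] = to_numpy(params["bn1.weight"])
    data_dict["bn_conv1"]["offset"] = to_numpy(params["bn1.bias"])

    # np.save stores the dict as a pickled 0-d object array, same as the original file.
    np.save(new_filename, data_dict)
```

Usage would look like save_weights(model, "weight_file.npy", "trained_weight_file.npy"), after which the unchanged load_weights(model, "trained_weight_file.npy", dtype) should read the file back with the same keys.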

EDIT 1: On the recommendation of @ad, I did:

data_dict = model.state_dict()

This saved all the weights under the model_params keys. The output of print(data_dict) was:

OrderedDict([('conv1.weight', tensor([[[[....]]]])), ('conv1.bias', tensor([....])), ('bn1.weight', tensor([....])), ('bn1.bias', tensor([....]))])

But what I need is to store the weights under the data_dict keys, so that I can read them from the .npy file with the same loading code. I also tried returning data_dict along with model_params from load_weights and then using data_dict = model.state_dict(), but it gave me an error on the model.load_state_dict(load_weights(model, weight_file, dtype)) line:

Traceback (most recent call last):
  model.load_state_dict(load_weights(model, weight_file, dtype))
  state_dict = state_dict.copy()
AttributeError: 'tuple' object has no attribute 'copy'
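The AttributeError is consistent with load_weights now returning a tuple (model_params, data_dict): load_state_dict expects a dict-like state_dict and calls .copy() on it, which a tuple does not have. A minimal sketch of unpacking the two values first; TinyCNN and load_weights_and_dict are hypothetical stand-ins for CNN_Model and the modified load_weights, and the placeholder data_dict replaces the real np.load(...).item() call:

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    # Hypothetical stand-in for CNN_Model with the same two layer names.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(8)


def load_weights_and_dict(model):
    # Hypothetical variant of load_weights that returns two values, as in the edit.
    model_params = model.state_dict()
    data_dict = {"conv1": {}, "bn_conv1": {}}  # placeholder for np.load(...).item()
    return model_params, data_dict


model = TinyCNN()
# Unpack the tuple and pass only the state dict to load_state_dict.
model_params, data_dict = load_weights_and_dict(model)
model.load_state_dict(model_params)
```

Keeping data_dict around this way also gives you the original nested structure to write the trained weights back into.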

I would do something like data_dict = model.state_dict().

You can read the official documentation, which includes an example of the output of state_dict(), here. There is also a GitHub repository that the repository you took your code from appears to be based on, and it likewise uses model.state_dict() to store the values.
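If the original nested layout is not required, the state_dict approach can be made concrete by converting the tensors to NumPy arrays and saving that flat dict; loading it back is then a direct key-for-key conversion instead of a hand-written mapping. Note that a file written this way uses keys such as "conv1.weight", so the original load_weights could not read it. A minimal sketch with hypothetical helper names save_state_dict_npy and load_state_dict_npy:

```python
import numpy as np
import torch


def save_state_dict_npy(model, filename):
    # Flat dict keyed by PyTorch names such as "conv1.weight", values as NumPy arrays.
    arrays = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
    np.save(filename, arrays)


def load_state_dict_npy(model, filename):
    # np.save pickles the dict into a 0-d object array, so allow_pickle and .item() are needed.
    arrays = np.load(filename, allow_pickle=True).item()
    state = {k: torch.from_numpy(v) for k, v in arrays.items()}
    model.load_state_dict(state)
```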
