
How to store model.state_dict() in a temp var for later use?

I tried to store the state dict of my model in a variable temporarily and wanted to restore it to my model later, but the content of this variable changed automatically as the model updated.

Here is a minimal example:

import torch as t
import torch.nn as nn
from torch.optim import Adam


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
loss_fc = nn.MSELoss()
optimizer = Adam(net.parameters())

weights = net.state_dict()
print(weights)

x = t.rand((5, 3))
y = t.rand((5, 2))
loss = loss_fc(net(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(weights)

I thought the two outputs would be the same, but I got (the exact values may vary due to random initialization):

OrderedDict([('fc.weight', tensor([[-0.5557,  0.0544, -0.2277],
        [-0.0793,  0.4334, -0.1548]])), ('fc.bias', tensor([-0.2204,  0.2846]))])
OrderedDict([('fc.weight', tensor([[-0.5547,  0.0554, -0.2267],
        [-0.0783,  0.4344, -0.1538]])), ('fc.bias', tensor([-0.2194,  0.2856]))])

The content of weights changed, which is so weird.

I also tried .copy() and t.no_grad() as follows, but they did not help:

with t.no_grad():
    weights = net.state_dict().copy()

Yes, I know that I can save the state dict using t.save(), but I just want to figure out what happened in the example above.

I'm using Python 3.8.5 and PyTorch 1.8.1.

Thanks for any help.

That's how OrderedDict works. Here's a simpler example:

from collections import OrderedDict

# a mutable variable
l = [1,2,3]

# an OrderedDict with an entry pointing to that mutable variable
x = OrderedDict([("a", l)])

# if you change the list
l[1] = 20

# the change is reflected in the OrderedDict
print(x)
# >> OrderedDict([('a', [1, 20, 3])])
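
The same aliasing happens when the dict entry is a tensor that is updated in place. A minimal sketch (the tensor w here is hypothetical, standing in for a model parameter):

import torch as t
from collections import OrderedDict

# a mutable tensor (hypothetical, standing in for a model parameter)
w = t.zeros(2, 2)

# an OrderedDict entry pointing to that same tensor (no copy is made)
d = OrderedDict([("w", w)])

# an in-place update, similar to what optimizer.step() does to parameters
w.add_(1.0)

# the change is reflected in the OrderedDict
print(d["w"])
# >> tensor([[1., 1.],
# >>         [1., 1.]])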

If you want to avoid that, you'll have to use deepcopy rather than a shallow copy:

from copy import deepcopy
x2 = deepcopy(x)

print(x2)
# >> OrderedDict([('a', [1, 20, 3])])

# now, if you change the list
l[2] = 30

# you do not change your copy
print(x2)
# >> OrderedDict([('a', [1, 20, 3])])

# but you keep changing the original dict
print(x)
# >> OrderedDict([('a', [1, 20, 30])])

As a Tensor is also mutable, the same behaviour is expected in your case: state_dict() returns references to the live parameter tensors, and optimizer.step() updates those tensors in place. Therefore, you can use:

from copy import deepcopy

weights = deepcopy(net.state_dict())
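
Applied to the example from the question, a sketch (reusing the Net class defined above): take the deep copy before training, and the snapshot stays fixed and can be loaded back with load_state_dict().

from copy import deepcopy

import torch as t
import torch.nn as nn
from torch.optim import Adam

net = Net()                           # Net as defined in the question
loss_fc = nn.MSELoss()
optimizer = Adam(net.parameters())

# deep copy: the stored tensors are independent of the live parameters
weights = deepcopy(net.state_dict())

x = t.rand((5, 3))
y = t.rand((5, 2))
loss = loss_fc(net(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# the snapshot is unchanged by the update ...
print(weights)

# ... and can be restored into the model later
net.load_state_dict(weights)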
