
How to load a learning rate scheduler state dict?

I have a model and a learning rate scheduler. I'm saving the model and optimizer using the state dict method shown here.

import torch
import torch.nn as nn
import torch.optim as optim

class net_x(nn.Module):
    def __init__(self):
        super(net_x, self).__init__()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 20)
        self.out = nn.Linear(20, 4)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x

nx = net_x()


r = torch.tensor([1.0, 2.0])
optimizer = optim.Adam(nx.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-2, max_lr=0.1, step_size_up=1,
    mode="triangular2", cycle_momentum=False)

path = 'opt.pt'
for epoch in range(10):
    optimizer.zero_grad()
    net_predictions = nx(r)
    loss = torch.sum(torch.randint(0,10,(4,)) - net_predictions)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print('loss:', loss)
    torch.save({
        'epoch': epoch,
        'net_x_state_dict': nx.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler': scheduler,  # the scheduler object itself is saved here, not its state dict
    }, path)

checkpoint = torch.load(path)
nx.load_state_dict(checkpoint['net_x_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])

The code runs just fine without the part where I load the scheduler state dict, so I'm not sure what I'm doing wrong. I'm trying to load the state dict as mentioned here, but I'm getting this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-7-e3217d6dd870> in <module>
     42 nx.load_state_dict(checkpoint['net_x_state_dict'])
     43 optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
---> 44 scheduler.load_state_dict(checkpoint['scheduler'])

~/anaconda3/lib/python3.7/site-packages/torch/optim/lr_scheduler.py in load_state_dict(self, state_dict)
     92                 from a call to :meth:`state_dict`.
     93         """
---> 94         self.__dict__.update(state_dict)
     95 
     96     def get_last_lr(self):

TypeError: 'CyclicLR' object is not iterable
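
The failure comes from how load_state_dict restores state: as the traceback shows, it calls self.__dict__.update(state_dict), and dict.update expects a mapping (or an iterable of key/value pairs). The checkpoint stored the CyclicLR object itself rather than its state dict, so update tries to iterate over the scheduler and raises the TypeError. A minimal sketch of the same error, independent of PyTorch (Foo is a hypothetical stand-in for the pickled scheduler object):

class Foo:           # stands in for the saved CyclicLR object
    pass

state = {}
state.update(Foo())  # TypeError: 'Foo' object is not iterable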

Since we have to extract the state_dict() values from the scheduler before saving, i.e. inside the torch.save() call, the code below will work:

import torch
import torch.nn as nn
import torch.optim as optim

class net_x(nn.Module):
    def __init__(self):
        super(net_x, self).__init__()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 20)
        self.out = nn.Linear(20, 4)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x

nx = net_x()


r = torch.tensor([1.0, 2.0])
optimizer = optim.Adam(nx.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-2, max_lr=0.1, step_size_up=1,
    mode="triangular2", cycle_momentum=False)

path = 'opt.pt'
for epoch in range(10):
    optimizer.zero_grad()
    net_predictions = nx(r)
    loss = torch.sum(torch.randint(0,10,(4,)) - net_predictions)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print('loss:', loss)
    torch.save({
        'epoch': epoch,
        'net_x_state_dict': nx.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),    # HERE IS THE CHANGE
    }, path)

checkpoint = torch.load(path)
nx.load_state_dict(checkpoint['net_x_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])
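
For completeness, here is a sketch of resuming from the checkpoint in a fresh process, assuming the same net_x definition and 'opt.pt' path as above; loading the scheduler's state dict restores its internal counters (such as last_epoch), so the cyclic schedule continues where it left off:

nx = net_x()
optimizer = optim.Adam(nx.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-2, max_lr=0.1, step_size_up=1,
    mode="triangular2", cycle_momentum=False)

checkpoint = torch.load(path)
nx.load_state_dict(checkpoint['net_x_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])

start_epoch = checkpoint['epoch'] + 1  # continue training from the next epoch
print(scheduler.get_last_lr())         # learning rate restored along with the scheduler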
