
pytorch: GRU cannot update hidden_state in-place

I ran into a problem while implementing a GRU network with PyTorch:

My code is as follows:

import torch

class GRU_model(torch.nn.Module):
    def __init__(self, device):
        super(GRU_model, self).__init__()
        self.h = torch.randn((1,1,5), device=device, dtype=torch.float)

        self.GRU_1 = torch.nn.GRU(input_size=5, hidden_size=5)

    def forward(self, a):
        output, self.h = self.GRU_1(a, self.h)
        return output

if __name__ == '__main__':
    learn_rate=1e-4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GRU_model(device).to(device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    for i in range(10):
        a = torch.randn((1, 1, 5), device=device, dtype=torch.float)
        output = model(a)
        loss = (a - output).mean()

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

And I get an error like this:

Traceback (most recent call last):
  File "C:/Users/Administrator_/Desktop/Graduation_Project/MIDI_Music_style_transfer/GRU_toy_in-place_hidden_states_change/main.py", line 40, in <module>
    loss.backward(retain_graph=True)
  File "C:\Users\Administrator_\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\Administrator_\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\autograd\__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [15, 5]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

All I want is to update the hidden_state in the GRU after each epoch, but it just doesn't work!

I would really appreciate any help!

In PyTorch, a computation graph is created for every iteration in an epoch. In each iteration, we execute the forward pass, compute the derivatives of the output w.r.t. the network's parameters, and update the parameters to fit the given examples. After the backward pass is performed, the graph is freed to save memory. In the next iteration, a brand-new graph is created, ready for back-propagation.

Because the computation graph is freed by default after the first backward pass, you will run into an error if you try to backward through the same graph a second time. That is why the following error message pops up:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time
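
A minimal sketch with a toy tensor (not the GRU above) that reproduces this behaviour:

import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2).sum()

loss.backward()       # first backward pass; the graph is freed afterwards
# loss.backward()     # calling backward again on the same graph would raise:
#                     # RuntimeError: Trying to backward through the graph a second time ...

loss = (x * 2).sum()  # a fresh forward pass builds a new graph,
loss.backward()       # so a second backward pass succeeds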

Resource

In your case, after specifying retain_graph=True, you see:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [15, 5]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
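
The message comes from autograd's version counter: a tensor that an operation saves for its backward pass must not change between the forward and backward passes. A minimal sketch with a toy tensor (unrelated to the GRU) that triggers the same error:

import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()      # exp() saves its output for the backward pass
loss = y.sum()
y.add_(1.0)      # in-place update bumps y's version counter
loss.backward()  # RuntimeError: one of the variables needed for gradient
                 # computation has been modified by an inplace operation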

This problem occurs because you update self.h in the forward pass: you are not allowed to modify it in-place, since it is needed for the grad computation (Source). This should work:

import torch


class GRU_model(torch.nn.Module):
    def __init__(self, device):
        super(GRU_model, self).__init__()
        self.GRU_1 = torch.nn.GRU(input_size=5, hidden_size=5)

    def forward(self, a, h):
        output, hh = self.GRU_1(a, h)
        return output, hh


if __name__ == '__main__':
    learn_rate = 1e-4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GRU_model(device).to(device=device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    def mse(output, target):
        loss = torch.mean((output - target)**2)
        return loss

    for i in range(10):
        a = torch.randn((1, 1, 5), device=device, dtype=torch.float, requires_grad=True)
        target = torch.randn((1, 1, 5), device=device)
        h = torch.randn((1, 1, 5), device=device)  # a fresh hidden state (and a fresh graph) every iteration

        optimizer.zero_grad()
        output, h = model(a, h)
        loss = mse(output, target)

        loss.backward()  # each iteration builds a fresh graph, so retain_graph=True is no longer needed
        optimizer.step()

Note, however, that in the snippet above the hidden state is re-initialized every iteration, so nothing is carried over between samples. The real problem is that the hidden_state should not take part in the gradient back-propagation at all. So simply adding one line, self.h = self.h.detach(), as shown below will definitely solve the problem:

    def forward(self, a):
        output, self.h = self.GRU_1(a, self.h)
        self.h = self.h.detach()  # keep the new value but cut it out of the graph
        return output
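
Putting it together, a minimal sketch of the question's training loop with the detached hidden state (keeping the loss from the question; retain_graph is no longer needed):

import torch

class GRU_model(torch.nn.Module):
    def __init__(self, device):
        super(GRU_model, self).__init__()
        # recurrent state, carried over between iterations
        self.h = torch.randn((1, 1, 5), device=device, dtype=torch.float)
        self.GRU_1 = torch.nn.GRU(input_size=5, hidden_size=5)

    def forward(self, a):
        output, self.h = self.GRU_1(a, self.h)
        self.h = self.h.detach()  # keep the value, drop the graph
        return output

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GRU_model(device).to(device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for i in range(10):
        a = torch.randn((1, 1, 5), device=device, dtype=torch.float)
        output = model(a)
        loss = (a - output).mean()

        optimizer.zero_grad()
        loss.backward()  # plain backward: each iteration owns its own graph
        optimizer.step()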

