繁体   English   中英

Pytorch:无法在需要 grad 的变量上调用 numpy()。 改用 var.detach().numpy()

[英]Pytorch: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead

调用tensor.numpy()给出错误:

RuntimeError:无法在需要 grad 的变量上调用 numpy()。 请改用 var.detach().numpy() 。

tensor.cpu().detach().numpy()给出了同样的错误。

错误重现

import torch

tensor1 = torch.tensor([1.0,2.0],requires_grad=True)

print(tensor1)
print(type(tensor1))

tensor1 = tensor1.numpy()

print(tensor1)
print(type(tensor1))

这导致tensor1 = tensor1.numpy()行出现完全相同的错误:

tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
Traceback (most recent call last):
  File "/home/badScript.py", line 8, in <module>
    tensor1 = tensor1.numpy()
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

Process finished with exit code 1

通用解决方案

这是在您的错误消息中向您建议的,只需将var替换为您的变量名

import torch

tensor1 = torch.tensor([1.0,2.0],requires_grad=True)

print(tensor1)
print(type(tensor1))

tensor1 = tensor1.detach().numpy()

print(tensor1)
print(type(tensor1))

按预期返回

tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
[1. 2.]
<class 'numpy.ndarray'>

Process finished with exit code 0

一些解释

除了实际值定义之外,您还需要将张量转换为不需要梯度的另一个张量。 这个其他张量可以转换为 numpy 数组。 参照。 这个讨论.pytorch 帖子 (我认为,更准确地说,为了从它的 pytorch Variable包装器中取出实际的张量,需要这样做,请参见另一个讨论.pytorch 帖子)。

我有同样的错误信息,但它是用于在 matplotlib 上绘制散点图。

我可以从这个错误消息中得到两个步骤:

  1. 使用以下命令导入fastai.basics库: from fastai.basics import *

  2. 如果您只使用torch库,请记住将requires_grad去掉:

     with torch.no_grad(): (your code)

对于现有张量

from torch.autograd import Variable

type(y)  # <class 'torch.Tensor'>

y = Variable(y, requires_grad=True)
y = y.detach().numpy()

type(y)  #<class 'numpy.ndarray'>

我在运行时期时遇到了这个问题,我将损失记录到了一个列表中

final_losses.append(loss)

一旦我跑遍了所有的时代,我想绘制输出图

plt.plot(range(epochs), final_loss)
plt.ylabel('RMSE Loss')
plt.xlabel('Epoch');

我在我的 Mac 上运行它,没有问题,但是,我需要在 Windows PC 上运行它,它产生了上面提到的错误。 所以,我检查了每个变量的类型。

Type(range(epochs)), type(final_losses)

范围,列表

好像应该没问题。

花了一点时间才意识到 final_losses 列表是张量列表。 然后我将它们转换为带有新列表变量 fi_los 的实际列表。

fi_los = [fl.item() for fl in final_losses ]
plt.plot(range(epochs), fi_los)
plt.ylabel('RMSE Loss')
plt.xlabel('Epoch');

成功!

最好的解决方案是使用torch.no_grad():上下文管理器,它在本地禁用梯度跟踪。

只需在此联系人管理器中编写您的代码,例如:

with torch.no_grad():
   graph_x = some_list_of_numbers
   graph_y = some_list_of_tensors

   plt.plot(graph_x, graph_y)
   plt.show()

写吧 :

y = tensor.detach().numpy()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM