![](/img/trans.png)
[英]RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead
[英]Pytorch: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead
调用tensor.numpy()
给出错误:
RuntimeError:无法在需要 grad 的变量上调用 numpy()。 请改用 var.detach().numpy() 。
tensor.cpu().detach().numpy()
给出了同样的错误。
import torch
tensor1 = torch.tensor([1.0,2.0],requires_grad=True)
print(tensor1)
print(type(tensor1))
tensor1 = tensor1.numpy()
print(tensor1)
print(type(tensor1))
这导致tensor1 = tensor1.numpy()
行出现完全相同的错误:
tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
Traceback (most recent call last):
File "/home/badScript.py", line 8, in <module>
tensor1 = tensor1.numpy()
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
Process finished with exit code 1
这是在您的错误消息中向您建议的,只需将var
替换为您的变量名
import torch
tensor1 = torch.tensor([1.0,2.0],requires_grad=True)
print(tensor1)
print(type(tensor1))
tensor1 = tensor1.detach().numpy()
print(tensor1)
print(type(tensor1))
按预期返回
tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
[1. 2.]
<class 'numpy.ndarray'>
Process finished with exit code 0
除了实际值定义之外,您还需要将张量转换为不需要梯度的另一个张量。 这个其他张量可以转换为 numpy 数组。 参照。 这个讨论.pytorch 帖子。 (我认为,更准确地说,为了从它的 pytorch Variable
包装器中取出实际的张量,需要这样做,请参见另一个讨论.pytorch 帖子)。
我有同样的错误信息,但它是用于在 matplotlib 上绘制散点图。
我可以从这个错误消息中得到两个步骤:
使用以下命令导入fastai.basics
库: from fastai.basics import *
如果您只使用torch
库,请记住将requires_grad
去掉:
with torch.no_grad(): (your code)
from torch.autograd import Variable
type(y) # <class 'torch.Tensor'>
y = Variable(y, requires_grad=True)
y = y.detach().numpy()
type(y) #<class 'numpy.ndarray'>
我在运行时期时遇到了这个问题,我将损失记录到了一个列表中
final_losses.append(loss)
一旦我跑遍了所有的时代,我想绘制输出图
plt.plot(range(epochs), final_loss)
plt.ylabel('RMSE Loss')
plt.xlabel('Epoch');
我在我的 Mac 上运行它,没有问题,但是,我需要在 Windows PC 上运行它,它产生了上面提到的错误。 所以,我检查了每个变量的类型。
Type(range(epochs)), type(final_losses)
范围,列表
好像应该没问题。
花了一点时间才意识到 final_losses 列表是张量列表。 然后我将它们转换为带有新列表变量 fi_los 的实际列表。
fi_los = [fl.item() for fl in final_losses ]
plt.plot(range(epochs), fi_los)
plt.ylabel('RMSE Loss')
plt.xlabel('Epoch');
成功!
最好的解决方案是使用torch.no_grad():
上下文管理器,它在本地禁用梯度跟踪。
只需在此联系人管理器中编写您的代码,例如:
with torch.no_grad():
graph_x = some_list_of_numbers
graph_y = some_list_of_tensors
plt.plot(graph_x, graph_y)
plt.show()
写吧 :
y = tensor.detach().numpy()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.