[英]Pytorch: how to get all the tensors in a graph
您可以在运行时访问整个计算图。 为此,您可以使用钩子。 这些是插入到nn.Module
的函数,用于推理和反向传播。
在推理时,您可以使用register_forward_hook
插入一个钩子。 对于反向传播,您可以使用register_backward_hook
(注意:在1.8.0版本中,此功能将被弃用,以支持register_full_backward_hook
)。
使用这两个函数,您基本上可以访问计算图上的任何张量。 是否要打印所有张量、打印形状,甚至插入断点进行调查,这完全取决于您。
这是一个可能的实现:
def forward_hook(module, input, output):
# ...
参数input
由 PyTorch 作为元组传递,并将包含传递给挂钩模块的 forward 函数的所有参数。
def backward_hook(module, grad_input, grad_output):
# ...
对于后向钩子, grad_input
和grad_output
都将是tuple ,并且根据模型的层具有不同的形状。
然后你可以在任何现有的nn.Module
上挂钩这些回调。 例如,您可以遍历模型中的所有子模块:
for module in model.children():
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)
要获取模块的名称,您可以包装钩子以将名称括起来并在模型的named_modules
上循环:
def forward_hook(name):
def hook(module, x, y):
print('%s: %s -> %s' % (name, list(x[0].size()), list(y.size())))
return hook
for name, module in model.named_children():
module.register_forward_hook(forward_hook(name))
可以在推理时打印以下内容:
fc1: [1, 100] -> [1, 10]
fc2: [1, 10] -> [1, 5]
fc3: [1, 5] -> [1, 1]
正如我所说,向后传递有点复杂。 我只能建议您探索和试验pdb
:
def backward_hook(module, grad_input, grad_output):
pdb.set_trace()
至于模型的参数,您可以通过调用module.parameters
轻松访问两个钩子中给定模块的参数。 这将返回一个生成器。
我只能祝你在探索你的模型时好运!
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.