繁体   English   中英

Pytorch:如何获取图中的所有张量

[英]Pytorch: how to get all the tensors in a graph

我想访问图形的所有张量实例。 例如,我可以检查张量是否分离或者我可以检查大小。 它可以在tensorflow 中完成。

想要图形的可视化。

您可以在运行时访问整个计算图。 为此,您可以使用钩子。 这些是插入到nn.Module的函数,用于推理和反向传播。

在推理时,您可以使用register_forward_hook插入一个钩子。 对于反向传播,您可以使用register_backward_hook (注意:在1.8.0版本中,此功能将被弃用,以支持register_full_backward_hook )。

使用这两个函数,您基本上可以访问计算图上的任何张量。 是否要打印所有张量、打印形状,甚至插入断点进行调查,这完全取决于您。

这是一个可能的实现:

def forward_hook(module, input, output):
    # ...

参数input由 PyTorch 作为元组传递,并将包含传递给挂钩模块的 forward 函数的所有参数。

def backward_hook(module, grad_input, grad_output):
    # ...

对于后向钩子, grad_inputgrad_output都将是tuple ,并且根据模型的层具有不同的形状。

然后你可以在任何现有的nn.Module上挂钩这些回调。 例如,您可以遍历模型中的所有子模块:

for module in model.children():
    module.register_forward_hook(forward_hook)
    module.register_backward_hook(backward_hook)

要获取模块的名称,您可以包装钩子以将名称括起来并在模型的named_modules上循环:

def forward_hook(name):
    def hook(module, x, y):
        print('%s: %s -> %s' % (name, list(x[0].size()), list(y.size())))
    return hook

for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))

可以在推理时打印以下内容:

fc1: [1, 100] -> [1, 10]
fc2: [1, 10] -> [1, 5]
fc3: [1, 5] -> [1, 1]

正如我所说,向后传递有点复杂。 我只能建议您探索和试验pdb

def backward_hook(module, grad_input, grad_output):
    pdb.set_trace()

至于模型的参数,您可以通过调用module.parameters轻松访问两个钩子中给定模块的参数。 这将返回一个生成器。

我只能祝你在探索你的模型时好运!

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM