简体   繁体   English

Pytorch 如何构建计算图

[英]How does Pytorch build the computation graph

Here is example pytorch code from the website:这是来自网站的示例 pytorch 代码:

class Net(nn.Module):

def __init__(self):
    super(Net, self).__init__()
    # 1 input image channel, 6 output channels, 3x3 square convolution
    # kernel
    self.conv1 = nn.Conv2d(1, 6, 3)
    self.conv2 = nn.Conv2d(6, 16, 3)
    # an affine operation: y = Wx + b
    self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

def forward(self, x):
    # Max pooling over a (2, 2) window
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # If the size is a square you can only specify a single number
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    x = x.view(-1, self.num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

In the forward function, we simply apply a series of transformations to x, but never explicitly define which objects are part of that transformation.在 forward 函数中,我们只是对 x 应用一系列转换,但从未明确定义哪些对象是该转换的一部分。 Yet when computing the gradient and updating the weights, Pytorch 'magically' knows which weights to update and how the gradient should be calculated.然而,在计算梯度和更新权重时,Pytorch“神奇地”知道要更新哪些权重以及应该如何计算梯度。

How does this process work?这个过程是如何运作的? Is there code analysis going on, or something else that I am missing?是否正在进行代码分析,或者我遗漏了什么?

Yes, there is implicit analysis on forward pass.是的,对前向传递有隐式分析。 Examine the result tensor, there is thingie like grad_fn= <CatBackward> , that's a link, allowing you to unroll the whole computation graph.检查结果张量,有像grad_fn= <CatBackward>这样的东西,这是一个链接,允许您展开整个计算图。 And it is built during real forward computation process, no matter how you defined your network module, object oriented with 'nn' or 'functional' way.它是在真正的前向计算过程中构建的,无论您如何定义网络模块,以“nn”或“功能”方式面向对象。

You can exploit this graph for net analysis, as torchviz do here: https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py您可以利用此图进行网络分析,就像torchviz在这里torchviz那样: https : //github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM