繁体   English   中英

如何在pytorch中同时运行多个分支?

[英]How to concurrently run multiple branches in pytorch?

我试图在 pytorch 中构建一个具有多个分支的网络 但是如何并行运行多个分支而不是一个一个运行它们

不像tensorflow或keras,pytorch使用动态图,所以我不能事先定义并发处理。

我查找了一些类似的 pytorch 网络官方实现,如 InceptionNet,却发现pytorch 连续运行多个分支

来自inception.py

def _forward(self, x):
    branch1x1 = self.branch1x1(x)

    branch5x5 = self.branch5x5_1(x)
    branch5x5 = self.branch5x5_2(branch5x5)

    branch3x3dbl = self.branch3x3dbl_1(x)
    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
    branch_pool = self.branch_pool(branch_pool)

    outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
    return outputs

四个分支一个接一个运行,先是branch1x1,然后是branch5x5,还有branch3x3dbl,branch_pool。 然后输出存储它们的结果,稍后将它们连接起来。

这不会浪费性能吗? 我们该如何处理?

一般来说,只要使用pytorch提供的函数,就不必关心网络执行的性能。

正如评论中所指出的,对 gpu 的所有调用都是异步的。 只要调用不依赖于数据,它就会被执行。 因此,在您的情况下,您有多个分支。 Pytorch 将调度所有操作并根据数据依赖关系执行它们。 由于您不在分支之间共享数据,因此它们将并行执行。

所以在你的情况下

branch3x3dbl = self.branch3x3dbl_1(x)
branch1x1 = self.branch1x1(x)
branch3x3dbl = self.branch3x3dbl_1(x)

可能或多或少同时执行。 以下所有层都是一样的。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM