[英]How to concurrently run multiple branches in pytorch?
我试图在 pytorch 中构建一个具有多个分支的网络。 但是如何并行运行多个分支而不是一个一个运行它们?
不像tensorflow或keras,pytorch使用动态图,所以我不能事先定义并发处理。
我查找了一些类似的 pytorch 网络官方实现,如 InceptionNet,却发现pytorch 连续运行多个分支。
def _forward(self, x):
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return outputs
四个分支一个接一个运行,先是branch1x1,然后是branch5x5,还有branch3x3dbl,branch_pool。 然后输出存储它们的结果,稍后将它们连接起来。
这不会浪费性能吗? 我们该如何处理?
一般来说,只要使用pytorch提供的函数,就不必关心网络执行的性能。
正如评论中所指出的,对 gpu 的所有调用都是异步的。 只要调用不依赖于数据,它就会被执行。 因此,在您的情况下,您有多个分支。 Pytorch 将调度所有操作并根据数据依赖关系执行它们。 由于您不在分支之间共享数据,因此它们将并行执行。
所以在你的情况下
branch3x3dbl = self.branch3x3dbl_1(x)
branch1x1 = self.branch1x1(x)
branch3x3dbl = self.branch3x3dbl_1(x)
可能或多或少同时执行。 以下所有层都是一样的。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.