How can I systematically reuse the results of delayed functions in Dask?
I am working on building a computation graph with Dask. Some of the intermediate values will be used multiple times, but I would like those calculations to run only once. I must be making a trivial mistake, because that's not what happens. Here is a minimal example:
In [1]: import dask
        dask.__version__
Out [1]: '1.0.0'
In [2]: class SumGenerator(object):
            def __init__(self):
                self.sources = []
            def register(self, source):
                self.sources += [source]
            def generate(self):
                return dask.delayed(sum)([s() for s in self.sources])
In [3]: sg = SumGenerator()
In [4]: @dask.delayed
        def source1():
            return 1.
        @dask.delayed
        def source2():
            return 2.
        @dask.delayed
        def source3():
            return 3.
In [5]: sg.register(source1)
        sg.register(source1)
        sg.register(source2)
        sg.register(source3)
In [6]: sg.generate().visualize()
Sadly I am unable to post the resulting graph image, but basically I see two separate nodes for the function source1 that was registered twice. Therefore the function is called twice. I would rather have it called once, with the result remembered and added twice to the sum. What would be the correct way to do that?
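A quick way to confirm the duplication without rendering the graph is to compare the keys of two calls and count how often the function body actually runs. This is a minimal sketch, assuming that each Delayed object exposes a `.key` attribute and that `dask.compute` accepts `scheduler="sync"`; the `calls` list is a hypothetical counter added only for illustration, and the synchronous scheduler keeps execution in the calling thread so the list sees every call.

```python
import dask

# Hypothetical counter, used only to observe how often the body runs.
calls = []

@dask.delayed  # default pure=False: every call gets a fresh random key
def source1():
    calls.append("source1")
    return 1.0

a = source1()
b = source1()
same_key = a.key == b.key  # two distinct nodes in the graph

total = dask.delayed(sum)([a, b])
# The synchronous scheduler runs tasks in this thread, so `calls` is updated.
result = dask.compute(total, scheduler="sync")

print(same_key)    # False
print(result)      # (2.0,)
print(len(calls))  # 2: the body ran once per node
```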
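You need to call the dask.delayed decorator with the pure=True argument. With pure=True, the key of each call is derived deterministically from the function and its arguments, so both calls to source1() produce the same key and Dask runs that node only once.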
From the Dask delayed docs:
delayed also accepts an optional keyword pure. If False, then subsequent calls will always produce a different Delayed.
If you know a function is pure (output only depends on the input, with no global state), then you can set pure=True.
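The effect of the flag can be checked directly on the keys. The sketch below uses the wrapping form `dask.delayed(func, pure=True)`, which is equivalent to decorating with `@dask.delayed(pure=True)`; the names `pure_source1` and `impure_source1` are illustrative only.

```python
import dask

def source1():
    return 1.0

# Same function wrapped twice: once as pure, once with the default (impure).
pure_source1 = dask.delayed(source1, pure=True)
impure_source1 = dask.delayed(source1)

# Pure: the key is derived deterministically from the function and its
# arguments, so two calls with the same (empty) arguments share one key.
pure_same = pure_source1().key == pure_source1().key

# Impure: each call is tagged with a random token, so the keys differ.
impure_same = impure_source1().key == impure_source1().key

print(pure_same, impure_same)  # True False
```

Because the two pure calls share a key, any graph that uses both of them contains that task only once.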
So, using that:
import dask

class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        return dask.delayed(sum)([s() for s in self.sources])

@dask.delayed(pure=True)
def source1():
    return 1.

@dask.delayed(pure=True)
def source2():
    return 2.

@dask.delayed(pure=True)
def source3():
    return 3.

sg = SumGenerator()
sg.register(source1)
sg.register(source1)
sg.register(source2)
sg.register(source3)
sg.generate().visualize()
Output and Graph
Using print(dask.compute(sg.generate())) gives (7.0,), the same result as your version, but without the extra source1 node, as seen in the image.
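To check this without the graph image, the answer's version can be instrumented to record each execution. This is a minimal sketch that repeats the answer's `SumGenerator` class so it is self-contained; the `calls` list is a hypothetical log appended to by each source, and the synchronous scheduler is assumed so the list is updated in the calling thread.

```python
import dask

calls = []  # hypothetical log of which source bodies actually ran

class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        return dask.delayed(sum)([s() for s in self.sources])

@dask.delayed(pure=True)
def source1():
    calls.append("source1")
    return 1.

@dask.delayed(pure=True)
def source2():
    calls.append("source2")
    return 2.

@dask.delayed(pure=True)
def source3():
    calls.append("source3")
    return 3.

sg = SumGenerator()
for src in (source1, source1, source2, source3):
    sg.register(src)

result = dask.compute(sg.generate(), scheduler="sync")
print(result)                  # (7.0,)
print(calls.count("source1"))  # 1: both registrations share one node
```

The sum still counts source1's value twice (1 + 1 + 2 + 3), but its body executes only once.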