
Parallelize tree creation with dask

I need help with a problem that I'm pretty sure dask can solve, but I don't know how to tackle it.

I need to construct a tree recursively.

For each node, if a criterion is met, a computation ( compute_val ) is done; otherwise two new children are created. The same treatment is applied to the children ( build ). Once all the children of a node have been processed, we can proceed to a merge ( merge ). The merge either fuses the children (if they both meet a criterion) or does nothing. So far I have only been able to parallelize the first level, and I don't know which dask tools I should use to be more effective. This is a simplified, sequential MRE of what I want to achieve:

import numpy as np
import time

class Node:
    def __init__(self, level):
        self.level = level
        self.val = None

def merge(node, childs):
    values = [child.val for child in childs]
    if all(values) and sum(values)<0.1:
        node.val = np.mean(values)
    else:
        node.childs = childs
    return node        

def compute_val():
    time.sleep(0.1)
    return np.random.rand(1)

def build(node):
    print(node.level)
    if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
        node.val = compute_val()
    else:
        childs = [build(Node(level=node.level+1)) for _ in range(2)]
        node = merge(node, childs)
    return node

tree = build(Node(level=0))

As I understand it, the way to tackle recursion (or any dynamic computation) in dask is to create tasks from within a task.

I was experimenting with something similar, so below is my five-minute illustrative solution. You'd have to optimise it according to the characteristics of your algorithm.

Keep in mind that tasks add overhead, so you'd want to chunk the computations for optimal results.

Relevant doc:

Api reference:

import numpy as np
import time
from dask.distributed import Client, worker_client

# Create a dask client
# For convenience, I'm creating a LocalCluster (8 single-threaded workers).
client = Client(threads_per_worker=1, n_workers=8)
client

class Node:
    def __init__(self, level):
        self.level = level
        self.val = None
        self.childs = None   # This was missing

def merge(node, childs):
    values = [child.val for child in childs]
    if all(values) and sum(values)<0.1:
        node.val = np.mean(values)
    else:
        node.childs = childs
    return node        

def compute_val():
    time.sleep(0.1)            # Is this required?
    return np.random.rand(1)

def build(node):
    print(node.level)
    if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
        node.val = compute_val()
    else:
        with worker_client() as client:
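            # build is random, so mark the tasks impure (pure=False); otherwise dask may
            # treat the two identical submit calls as the same task and reuse one result.
            child_futures = [client.submit(build, Node(level=node.level+1), pure=False) for _ in range(2)]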
            childs = client.gather(child_futures)
        node = merge(node, childs)
    return node

tree_future = client.submit(build, Node(level=0))
tree = tree_future.result()
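
To act on the earlier point about task overhead, one option is to spawn dask tasks only near the root and build the deeper levels in-process, so that each task carries a larger chunk of work. The following is a rough sketch of that idea, not a tuned solution. The `parallel_depth` and `max_level` parameters and the `count_leaves` helper are my own additions for illustration. The standard-library `random` stands in for NumPy to keep the sketch self-contained, and the merge check uses `is not None` so that a value of exactly 0 still counts as computed.

```python
import random
import time


class Node:
    def __init__(self, level):
        self.level = level
        self.val = None
        self.childs = None


def compute_val():
    time.sleep(0.01)  # stand-in for the expensive computation
    return random.random()


def merge(node, childs):
    values = [child.val for child in childs]
    # Fuse only when both children hold a value and the values are small.
    if all(v is not None for v in values) and sum(values) < 0.1:
        node.val = sum(values) / len(values)
    else:
        node.childs = childs
    return node


def build(node, parallel_depth=3, max_level=5):
    """Build the subtree under `node`.

    Levels above `parallel_depth` submit one dask task per child;
    deeper levels recurse in-process (hypothetical cutoff, tune as needed).
    """
    if (random.random() < 0.1 and node.level > 1) or node.level > max_level:
        node.val = compute_val()
        return node
    children = [Node(node.level + 1) for _ in range(2)]
    if node.level < parallel_depth:
        # Imported lazily so the purely sequential path works without a cluster.
        from dask.distributed import worker_client
        with worker_client() as client:
            futures = [
                client.submit(build, child, parallel_depth, max_level, pure=False)
                for child in children
            ]
            childs = client.gather(futures)
    else:
        childs = [build(child, parallel_depth, max_level) for child in children]
    return merge(node, childs)


def count_leaves(node):
    """Count the nodes that hold a value (computed or fused)."""
    if node.childs is None:
        return 1
    return sum(count_leaves(child) for child in node.childs)
```

With a client like the one created above, this would be started with `client.submit(build, Node(0), pure=False).result()`. Setting `parallel_depth=0` runs everything sequentially, which is handy for comparing timings; the best cutoff depends on how expensive `compute_val` is relative to the per-task overhead.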
