
Multiprocessing in Python: Is there a way to use pool.imap without accumulating memory?

I am using the multiprocessing module in Python to train neural networks with keras in parallel, using a Pool(processes = 4) object with imap. This steadily uses more and more memory after every "cycle", i.e. every 4 processes, until it finally crashes.

I used the memory_profiler module to track my memory usage over time, training 12 networks. Here's using vanilla imap:

(memory usage plot: vanilla imap)

If I put maxtasksperchild = 1 in Pool:

(memory usage plot: maxtasksperchild = 1)

If I use imap(chunksize = 3):

(memory usage plot: chunksize = 3)

In the latter case, where everything works out fine, I'm only sending off a single batch to every process in the pool, so it seems that the problem is that the processes carry information about previous batches. If so, can I force the pool to not do that?

Even though the chunks solution seems to work, I'd rather not use it, because

  • I'd like to track progress using the tqdm module, and in the chunks case it only updates after every chunk, which effectively means it won't really track anything at all, since all the chunks finish at the same time (in this example)
  • Currently all networks take exactly the same time to train, but I'd like to allow for separate training times, in which case the chunks solution could end up giving one process all the long training times.

Here's a code snippet for the vanilla case. In the other two cases I just changed the maxtasksperchild parameter in Pool and the chunksize parameter in imap:

from multiprocessing import Pool
from tqdm import tqdm

def train_network(network):
    (...)
    return score

pool = Pool(processes = 4)
scores = pool.imap(train_network, networks)
scores = tqdm(scores, total = networks.size)

for (network, score) in zip(networks, scores):
    network.score = score

pool.close()
pool.join()
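For reference, the only changes in the other two cases were the following (a minimal sketch; train_network and networks are the same objects as in the snippet above):

from multiprocessing import Pool

# Case 2: recycle each worker process after a single task, so no state
# from earlier networks survives inside a worker
pool = Pool(processes = 4, maxtasksperchild = 1)
scores = pool.imap(train_network, networks)

# Case 3: send each worker a contiguous block of 3 networks at once
pool = Pool(processes = 4)
scores = pool.imap(train_network, networks, chunksize = 3)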

Unfortunately, the multiprocessing module in Python comes at a considerable cost: data is generally not shared between processes and has to be replicated. This will change starting from Python 3.8.

https://docs.python.org/3.8/library/multiprocessing.shared_memory.html

Although the official release of Python 3.8 is scheduled for 21 October 2019, you can already download it from GitHub.
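As a rough illustration of what the new module makes possible (a minimal sketch with a made-up array, unrelated to the keras workload above): a NumPy array can live in a named shared-memory block and be attached to from a worker process without copying the data.

import numpy as np
from multiprocessing import Process, shared_memory

def worker(shm_name, shape, dtype):
    # Attach to the existing block by name -- no copy is made
    shm = shared_memory.SharedMemory(name = shm_name)
    data = np.ndarray(shape, dtype = dtype, buffer = shm.buf)
    data *= 2          # changes are visible to the parent process
    shm.close()

if __name__ == '__main__':
    arr = np.arange(10, dtype = np.float64)

    # Create a shared block and copy the array into it once
    shm = shared_memory.SharedMemory(create = True, size = arr.nbytes)
    shared = np.ndarray(arr.shape, dtype = arr.dtype, buffer = shm.buf)
    shared[:] = arr

    p = Process(target = worker, args = (shm.name, arr.shape, arr.dtype))
    p.start()
    p.join()

    print(shared)      # reflects the worker's modification
    shm.close()
    shm.unlink()       # release the block when finished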

I came up with a solution that seems to work: I ditched the pool and made my own simple queuing system. Aside from the memory usage not increasing (it does increase ever so slightly, but I think that's because I'm storing some dictionaries as a log), it even consumes less memory than the chunks solution above:

(memory usage plot: queue-based solution)

I have no idea why that's the case. Perhaps the Pool objects just take up a lot of memory? Anyway, here's my code:

import multiprocessing as mp
from tqdm import tqdm

def train_network(network):
    (...)
    return score

# Define queues to organise the parallelising
todo = mp.Queue(maxsize = networks.size + 4)
done = mp.Queue(maxsize = networks.size)

# Populate the todo queue
for idx in range(networks.size):
    todo.put(idx)

# Add -1's which will be an effective way of checking
# if all todo's are finished
for _ in range(4):
    todo.put(-1)

def worker(todo, done):
    ''' Network scoring worker. '''
    from queue import Empty
    while True:
        try:
            # Fetch the next todo
            idx = todo.get(timeout = 1)
        except Empty:
            # The queue is never empty, so the silly worker has to go
            # back and try again
            continue

        # If we have reached a -1 then stop
        if idx == -1:
            break
        else:
            # Score the network and store it in the done queue
            score = train_network(networks[idx])
            done.put((idx, score))

# Construct our four processes
processes = [mp.Process(target = worker,
    args = (todo, done)) for _ in range(4)]

# Daemonise the processes, so they are terminated if the
# main process exits, and start them
for p in processes:
    p.daemon = True
    p.start()

# Set up the iterable with all the scores, and set
# up a progress bar
idx_scores = (done.get() for _ in networks)
pbar = tqdm(idx_scores, total = networks.size)

# Compute all the scores in parallel
for (idx, score) in pbar:
    networks[idx].score = score

# Join up the processes and close the progress bar
for p in processes:
    p.join()
pbar.close()
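For comparison, the same per-network progress can also be sketched with the standard-library concurrent.futures module, which submits every network as its own task and yields results as they finish (again assuming train_network and networks from above; note that, like Pool, it keeps long-lived worker processes, so it may show the same memory growth):

from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

with ProcessPoolExecutor(max_workers = 4) as executor:
    # One future per network; no chunking, so uneven training
    # times balance out across the four workers
    futures = {executor.submit(train_network, networks[idx]): idx
               for idx in range(networks.size)}

    # as_completed yields each future as soon as it finishes,
    # so tqdm updates once per network
    for future in tqdm(as_completed(futures), total = networks.size):
        idx = futures[future]
        networks[idx].score = future.result()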
