
How to get rid of zombie processes using torch.multiprocessing.Pool (Python)

I am using torch.multiprocessing.Pool to speed up my NN in inference, like this:

import functools

import torch.multiprocessing as mp
from tqdm import tqdm

# CUDA only plays nicely with the 'forkserver' or 'spawn' start methods
mp = mp.get_context('forkserver')

def parallel_predict(predict_func, sequences, args):
    predicted_cluster_ids = []
    pool = mp.Pool(args.num_workers, maxtasksperchild=1)
    out = pool.imap(
        func=functools.partial(predict_func, args=args),
        iterable=sequences,
        chunksize=1)
    for item in tqdm(out, total=len(sequences), ncols=85):
        predicted_cluster_ids.append(item)
    # tear down the pool (see Note 4: terminate/join make no difference)
    pool.close()
    pool.terminate()
    pool.join()
    return predicted_cluster_ids

Note 1) I am using imap because I want to be able to show a progress bar with tqdm.
Note 2) I tried both forkserver and spawn, with no luck. I cannot use the other start methods because of how poorly they interact with CUDA.
Note 3) I am using maxtasksperchild=1 and chunksize=1, so a new process is spawned for each sequence in sequences.
Note 4) Adding or removing pool.terminate() and pool.join() makes no difference.
Note 5) predict_func is a method of a class I created (a minimal sketch is shown below). I could also pass the whole model to parallel_predict, but it does not change anything.
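For reference, this is roughly what that method looks like. The class name, constructor, and device handling here are made up for illustration; only the signature (one sequence plus args) matters for the question:

import torch

class Predictor:
    def __init__(self, model, device):
        # hypothetical wrapper around my NN; the real class is larger
        self.model = model.to(device).eval()
        self.device = device

    def predict_func(self, sequence, args):
        # inference on a single sequence, returning its predicted cluster id
        with torch.no_grad():
            logits = self.model(sequence.to(self.device))
            return logits.argmax(dim=-1).cpu()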

Everything works fine except that after a while I run out of memory on the CPU (on the GPU everything works as expected). Using htop to monitor memory usage, I notice that for every process spawned by the pool I get a zombie that uses 0.4% of the memory. These zombies never get cleared, so they keep using space. Still, parallel_predict does return the correct result and the computation goes on. My script is structured so that it runs validation multiple times, so the next time parallel_predict is called the zombies add up.
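To confirm what htop shows, here is a quick diagnostic sketch that counts zombie children programmatically. It uses psutil, which is not part of my script, purely for inspection:

import psutil

def count_zombie_children():
    # children of the current process that have exited but were never reaped
    me = psutil.Process()
    zombies = [c for c in me.children(recursive=True)
               if c.status() == psutil.STATUS_ZOMBIE]
    return len(zombies)

print(f"zombie children: {count_zombie_children()}")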

This is what I get in htop: [screenshot of htop showing the zombie worker processes]

Usually these zombies get cleared after Ctrl-C, but in some rare cases I need to resort to killall.

Is there some way I can force the Pool to close them?

UPDATE: I tried to kill the zombie processes using this:

def kill(pool):
    import multiprocessing.pool
    import os
    import signal
    # stop the pool from repopulating dead workers
    pool._state = multiprocessing.pool.TERMINATE
    pool._worker_handler._state = multiprocessing.pool.TERMINATE
    for p in pool._pool:
        os.kill(p.pid, signal.SIGKILL)
    # .is_alive() will reap the dead processes
    while any(p.is_alive() for p in pool._pool):
        pass
    pool.terminate()

But it does not work: it gets stuck at pool.terminate().

UPDATE 2: I tried to use the initializer argument of Pool to install a signal handler in each worker, like this:

import os
import signal

def process_initializer():
    # runs once in each worker: install a SIGTERM handler that exits cleanly
    def handler(_signal, frame):
        print('exiting')
        exit(0)
    signal.signal(signal.SIGTERM, handler)


def parallel_predict(predict_func, sequences, args):
    predicted_cluster_ids = []
    with mp.Pool(args.num_workers, initializer=process_initializer, maxtasksperchild=1) as pool:
        out = pool.imap(
            func=functools.partial(predict_func, args=args),
            iterable=sequences,
            chunksize=1)
        for item in tqdm(out, total=len(sequences), ncols=85):
            predicted_cluster_ids.append(item)
        # explicitly SIGTERM every worker before tearing the pool down
        for p in pool._pool:
            os.kill(p.pid, signal.SIGTERM)
        pool.close()
        pool.terminate()
        pool.join()
    return predicted_cluster_ids

But again, it does not free the memory.

Ok, I have more insights to share. This is not actually a bug: it is the "intended" behavior of Python's multiprocessing module (which torch.multiprocessing wraps). What happens is that, although the Pool terminates all the processes, the memory is not released (given back to the OS). This is also stated in the documentation, though in a rather confusing way. The documentation says:

Worker processes within a Pool typically live for the complete duration of the Pool's work queue

but also:

A frequent pattern found in other systems (such as Apache, mod_wsgi, etc) to free resources held by workers is to allow a worker within a pool to complete only a set amount of work before being exiting, being cleaned up and a new process spawned to replace the old one. The maxtasksperchild argument to the Pool exposes this ability to the end user

but the "clean up" does NOT happen.

To make things worse, I found this post, which recommends using maxtasksperchild=1. This makes the leak worse, because the number of zombies then grows with the number of data points to be predicted, and since pool.close() does not free the memory, they add up.

This is very bad if you are using multiprocessing, for example, for validation. At every validation step I was reinitializing the pool, but the memory from the previous iteration was never freed.

The SOLUTION here is to move pool = mp.Pool(args.num_workers) outside the training loop, so the pool never gets closed and reopened and always reuses the same processes. NOTE: remember to also remove maxtasksperchild=1 and chunksize=1. A sketch of this restructuring is below.
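A minimal sketch of what I mean, assuming a training loop with a validation step per epoch (the train function, args.num_epochs, and model.predict_func are made up for illustration; the only point is where the pool is created):

import functools

import torch.multiprocessing as mp
from tqdm import tqdm

mp = mp.get_context('forkserver')

def parallel_predict(pool, predict_func, sequences, args):
    # the pool is passed in and reused; no maxtasksperchild=1, no chunksize=1
    out = pool.imap(functools.partial(predict_func, args=args), sequences)
    return [item for item in tqdm(out, total=len(sequences), ncols=85)]

def train(model, val_sequences, args):
    # create the pool ONCE, outside the training loop, and reuse it
    pool = mp.Pool(args.num_workers)
    try:
        for epoch in range(args.num_epochs):
            ...  # one epoch of training (omitted)
            predicted = parallel_predict(pool, model.predict_func, val_sequences, args)
    finally:
        pool.close()
        pool.join()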

I think this should be included in the best practices page.

BTW, in my opinion this behavior of the multiprocessing library should be considered a bug and should be fixed on the Python side (not the PyTorch side).
