
Python Multiprocessing Skip Child Segfault

I'm trying to use multiprocessing for a function that can potentially segfault (I have no control over this at the moment). In cases where a child process hits a segfault, I want only that child to fail, and all other child tasks to continue and return their results.

I've already switched from multiprocessing.Pool to concurrent.futures.ProcessPoolExecutor to avoid the issue of the child process hanging forever (or until an arbitrary timeout), as documented in this bug: https://bugs.python.org/issue22393 .

However, the issue I face now is that when the first child task hits a segfault, all in-flight child processes get marked as broken (concurrent.futures.process.BrokenProcessPool).

Is there a way to only mark actually broken child processes as broken?

Code I'm running in Python 3.7.4:

import concurrent.futures
import ctypes
from time import sleep


def do_something(x):
    print(f"{x}; in do_something")
    sleep(x*3)
    if x == 2:
        # raise a segmentation fault internally
        return x, ctypes.string_at(0)
    return x, x-1


nums = [1, 2, 3, 1.5]
executor = concurrent.futures.ProcessPoolExecutor()
result_futures = []
for num in nums:
    # Using submit with a list instead of map lets you get past the first exception
    # Example: https://stackoverflow.com/a/53346191/7619676
    future = executor.submit(do_something, num)
    result_futures.append(future)

# Wait for all results
concurrent.futures.wait(result_futures)

# After a segfault is hit for any child process (i.e. is "terminated abruptly"), the process pool becomes unusable
# and all running/pending child processes' results are set to broken
for future in result_futures:
    try:
        print(future.result())
    except concurrent.futures.process.BrokenProcessPool:
        print("broken")

Result:

(1, 0)
broken
broken
(1.5, 0.5)

Desired result:

(1, 0)
broken
(3, 2)
(1.5, 0.5)

Based on @Richard Sheridan's answer, I ended up using the code below. This version doesn't require setting a timeout, which is something I couldn't do for my use case.

import ctypes
import multiprocessing
from typing import Dict
from time import sleep


def do_something(x, result):
    print(f"{x} starting")
    sleep(x * 3)
    if x == 2:
        # raise a segmentation fault internally
        ctypes.string_at(0)
    y = x
    print(f"{x} done")
    result.put(y)

def wait_for_process_slot(
    processes: Dict,
    concurrency: int = multiprocessing.cpu_count() - 1,
    wait_sec: int = 1,
) -> int:
    """Blocks main process if `concurrency` processes are already running.

    Alternative to `multiprocessing.Semaphore.acquire`
    useful for when child processes might fail and not be able to signal.
    Relies instead on the main's (parent's) tracking of `multiprocessing.Process`es.

    """
    counter = 0
    while True:
        counter = sum([1 for i, p in processes.items() if p.is_alive()])
        if counter < concurrency:
            return counter
        sleep(wait_sec)


if __name__ == "__main__":
    # "spawn" results in an OSError b/c pickling a segfault fails?
    ctx = multiprocessing.get_context()
    manager = ctx.Manager()
    results_queue = manager.Queue(maxsize=-1)

    concurrency = multiprocessing.cpu_count() - 1  # reserve 1 CPU for waiting
    nums = [3, 1, 2, 1.5]
    all_processes = {}
    for idx, num in enumerate(nums):
        num_running_processes = wait_for_process_slot(all_processes, concurrency)

        p = ctx.Process(target=do_something, args=(num, results_queue), daemon=True)
        all_processes.update({idx: p})
        p.start()

    # Wait for the last batch of processes not blocked by wait_for_process_slot to finish
    for p in all_processes.values():
        p.join()

    # Check last batch of processes for bad processes
    # Relies on all processes having finished (the p.joins above)
    bad_nums = [idx for idx, p in all_processes.items() if p.exitcode != 0]
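The `exitcode` check above works because a child killed by a signal gets a negative `exitcode` equal to that signal's number, while a child that raised a Python exception exits with code 1. A minimal standalone sketch of telling the two apart (assuming a POSIX platform where the "fork" start method is available):

```python
import ctypes
import multiprocessing
import signal


def crash():
    # dereferencing address 0 raises SIGSEGV in the child
    ctypes.string_at(0)


def ok():
    pass


if __name__ == "__main__":
    ctx = multiprocessing.get_context("fork")
    procs = {"crash": ctx.Process(target=crash), "ok": ctx.Process(target=ok)}
    for p in procs.values():
        p.start()
    for p in procs.values():
        p.join()
    for name, p in procs.items():
        if p.exitcode == -signal.SIGSEGV:
            print(f"{name}: segfaulted")
        elif p.exitcode != 0:
            print(f"{name}: exited with code {p.exitcode}")
        else:
            print(f"{name}: ok")
```

Checking for `-signal.SIGSEGV` specifically, rather than just `!= 0`, lets you distinguish segfaults from ordinary failures.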

multiprocessing.Pool and concurrent.futures.ProcessPoolExecutor both make assumptions about the interactions between the workers and the main process that are violated if any one process is killed or segfaults, so they do the safe thing and mark the whole pool as broken. To get around this, you will need to build your own pool with different assumptions, directly using multiprocessing.Process instances.

This might sound intimidating but a list and a multiprocessing.Manager will get you pretty far:

import multiprocessing
import ctypes
import queue
from time import sleep

def do_something(job, result):
    while True:
        x=job.get()
        print(f"{x}; in do_something")
        sleep(x*3)
        if x == 2:
            # raise a segmentation fault internally
            return x, ctypes.string_at(0)
        result.put((x, x-1))

nums = [1, 2, 3, 1.5]

if __name__ == "__main__":
    # you ARE using the spawn context, right?
    ctx = multiprocessing.get_context("spawn")
    manager = ctx.Manager()
    job_queue = manager.Queue(maxsize=-1)
    result_queue = manager.Queue(maxsize=-1)
    pool = [
        ctx.Process(target=do_something, args=(job_queue, result_queue), daemon=True)
        for _ in range(multiprocessing.cpu_count())
    ]
    for proc in pool:
        proc.start()
    for num in nums:
        job_queue.put(num)
    try:
        while True:
            # Timeout is our only signal that no more results coming
            print(result_queue.get(timeout=10))
    except queue.Empty:
        print("Done!")
    print(pool)  # will see one dead Process 

This "Pool" is a little inflexible, and you will probably want to customize it for your application's specific needs. But you can definitely skip right over segfaulting workers.
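One customization worth considering: each segfault permanently costs the pool above one worker slot. A rough way to keep the pool at full strength is to poll for dead workers and respawn them. This sketch uses a hypothetical `replenish` helper that is not part of the answer's code:

```python
import multiprocessing


def replenish(pool, ctx, target, args):
    """Return a pool of the same size as `pool`, replacing any dead
    workers with freshly started daemon processes (hypothetical helper)."""
    alive = [p for p in pool if p.is_alive()]
    while len(alive) < len(pool):
        p = ctx.Process(target=target, args=args, daemon=True)
        p.start()
        alive.append(p)
    return alive
```

Calling this periodically from the result-draining loop would let the pool survive repeated segfaults, though the job that killed a worker is lost unless you track it and re-queue it yourself.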

When I went down this rabbit hole, interested in cancelling specific submissions to a worker pool, I eventually wound up writing a whole library to integrate into Trio async apps: trio-parallel. Hopefully you won't need to go that far!
