Can anyone help me understand why this simple example of trying to speed up a for loop using python's multiprocessing
module produces unstable results? I use a Manager.List
to store the values from the child processes.
Clearly I'm doing at least one thing wrong. What would be the correct way to do this?
import numpy as np
import multiprocessing
from matplotlib import pyplot as plt
from functools import partial
from multiprocessing import Manager
def run_parallel(x_val, result):
val = np.arctan(x_val)
result.append(val)
def my_func(x_array, parallel=False):
if not parallel:
result = []
for k in x_array:
result.append(np.arctan(k))
return result
else:
manager = Manager()
m_result = manager.list()
pool = multiprocessing.Pool(4)
pool.map(partial(run_parallel, result=m_result), x_array)
return list(m_result)
test_x = np.linspace(0.1,1,50)
serial = my_func(test_x,parallel=False)
parallel = my_func(test_x,parallel=True)
plt.figure()
plt.plot(test_x, serial, label='serial')
plt.plot(test_x,parallel, label='parallel')
plt.legend(loc='best')
plt.show()
The output I'm getting looks like this
and it looks different every time this runs.
I added some print functions and it turned out that the order of elements from x_array is arbitrary... That's why it looks so weird. I think you should keep argument and value of arctan pairs and then order it by argument value
EDIT
I read more and it turned out that map
returns values in order... This works as you wanted:
import numpy as np
import multiprocessing
from matplotlib import pyplot as plt
from functools import partial
from multiprocessing import Manager
def run_parallel(x_val, result):
val = np.arctan(x_val)
return val
def my_func(x_array, parallel=False):
if not parallel:
result = []
for k in x_array:
result.append(np.arctan(k))
return result
else:
manager = Manager()
m_result = manager.list()
pool = multiprocessing.Pool(4)
x = pool.map(partial(run_parallel, result=m_result), x_array)
return list(x)
test_x = np.linspace(0.1,1,50)
parallel = my_func(test_x,parallel=True)
plt.figure()
plt.plot(test_x,parallel, label='parallel')
plt.legend(loc='best')
plt.show()
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.