[英]Python multiprocessing pool.map hangs after sklearn function call
我正在尝试使用multiprocessing
在二维数组和二维数组集合之间执行一些计算。 假设我有一个矩阵mat1
和一个矩阵集合test
,我想在其中计算mat1
和test
元素之间的所有矩阵乘法。 我使用多处理并行运行计算,因为test
的大小非常大。 但是,我注意到即使是一个小test
,计算也永远不会完成。 具体来说,该程序似乎永远不会完成矩阵乘法计算。 似乎是对特定sklearn
函数的调用导致了该问题。 我编写了以下代码来说明这一点(我使用partial
而不是starmap
因为我想稍后使用imap
和tqdm
):
from multiprocessing import Pool
from functools import partial
import numpy as np
import sklearn as sklearn
def bar(y, x):
# this does not seem to complete
mul = x @ y.T
# so this does not print
print('done')
return mul
def foo():
mat1 = np.ones((1000000, 14))
test = (np.ones((1,14)), np.ones((1,14)))
# these will finish
print(mat1 @ test[0].T)
print(mat1 @ test[1].T)
with Pool(6) as pool:
result = pool.map(partial(bar, x=mat1), test
p.close()
p.join()
if __name__ == "__main__":
# Causes the hang
sklearn.metrics.pairwise.rbf_kernel(np.ones((9000, 14)),
np.ones((9000, 14)))
foo()
注意:对于那些不熟悉partial
,这是来自文档:
functools.partial(func[,*args][, **keywords])
返回一个新的部分对象,当调用该对象时,其行为类似于使用位置参数 args 和关键字参数关键字调用的 func 。
我被迫手动停止执行,否则它将永远运行。 我没有正确使用multiprocessing
吗?
对于那些感兴趣的人,可以在下面找到强制停止后的完整回溯:
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) <ipython-input-18-6c073b574e37> in <module>
8
9 sklearn.metrics.pairwise.rbf_kernel(np.ones((9000, 14)), np.ones((9000, 14)))
---> 10 foo()
11
<ipython-input-17-d183fc19ae3c> in foo()
11 with Pool(6) as pool:
12 # this will not finish
---> 13 result = pool.map(partial(bar, x=mat1), test)
14 p.close()
15 p.join()
~/anaconda3/lib/python3.7/multiprocessing/pool.py in map(self, func, iterable, chunksize)
266 in a list that is returned.
267 '''
--> 268 return self._map_async(func, iterable, mapstar, chunksize).get()
269
270 def starmap(self, func, iterable, chunksize=None):
~/anaconda3/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
649
650 def get(self, timeout=None):
--> 651 self.wait(timeout)
652 if not self.ready():
653 raise TimeoutError
~/anaconda3/lib/python3.7/multiprocessing/pool.py in wait(self, timeout)
646
647 def wait(self, timeout=None):
--> 648 self._event.wait(timeout)
649
650 def get(self, timeout=None):
~/anaconda3/lib/python3.7/threading.py in wait(self, timeout)
550 signaled = self._flag
551 if not signaled:
--> 552 signaled = self._cond.wait(timeout)
553 return signaled
554
~/anaconda3/lib/python3.7/threading.py in wait(self, timeout)
294 try: # restore state no matter what (e.g., KeyboardInterrupt)
295 if timeout is None:
--> 296 waiter.acquire()
297 gotit = True
298 else:
KeyboardInterrupt:
更新1:
经过更多的调试,我发现了一些奇怪的东西。 实现sokato的代码后,我设法修复了这个例子。 但是,在main()
foo()
之前调用以下sklearn
函数时,我可以再次触发该问题:
sklearn.metrics.pairwise.rbf_kernel(np.ones((9000, 14)), np.ones((9000, 14)))
我已经更新了原始帖子以反映这一点。
您需要关闭多处理池。 例如
def bar(y, x):
# this does not seem to complete
mul = x @ y.T
# so this does not print
print('done')
return mul
def foo():
mat1 = np.ones((1000000, 14))
test = (np.ones((1,14)), np.ones((1,14)))
with Pool(5) as p:
# this will not finish
result = p.map(partial(bar, x=mat1), test)
p.close()
if __name__ == "__main__":
foo()
为了适合您的确切语法,您可以这样做
pool = Pool(6)
result = pool.map(partial(bar, x=mat1), test)
pool.close()
如果您有兴趣了解更多信息,我鼓励您查看文档。 https://docs.python.org/3.4/library/multiprocessing.html?highlight=process#multiprocessing.pool.Pool
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.