[英]how to pass parameters other than data through pool.imap() function for multiprocessing in python?
我正在研究高光谱图像。 为了减少图像中的噪声,我使用 pywt package 进行小波变换。 当我正常执行此操作(串行处理)时,它工作顺利。 但是当我尝试使用多核对图像进行小波变换来实现并行处理时,我必须传递某些参数,例如
但是我无法使用池 object 传递这些参数,当我使用 pool.imap() 时,我只能将数据作为参数传递。 但是当我使用 pool.apply_async() 时,它需要更多的时间,而且 output 的顺序也不一样。 我在这里添加代码以供参考:
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing as mp
import os
import time
from math import log10, sqrt
import pywt
import tifffile
def spec_trans(d,wav_fam,threshold_val,thresh_type):
data=np.array(d,dtype=np.float64)
data_dec=decomposition(data,wav_fam)
data_t=thresholding(data_dec,threshold_val,thresh_type)
data_rec=reconstruction(data_t,wav_fam)
return data_rec
if __name__ == '__main__':
#input
X=tifffile.imread('data/Classification/university.tif')
#take paramaters
threshold_val=float(input("Enter the value for image thresholding: "))
print("The available wavelet functions:",pywt.wavelist())
wav_fam=input("Choose a wavelet function for transformation: ")
threshold_type=['hard','soft']
print("The available wavelet functions:",threshold_type)
thresh_type=input("Choose a type for threshholding technique: ")
start=time.time()
p = mp.Pool(4)
jobs=[]
for dataBand in xmp:
jobs.append(p.apply_async(spec_trans,args=(dataBand,wav_fam,threshold_val,thresh_type)))
transformedX=[]
for jobBit in jobs:
transformedX.append(jobBit.get())
end=time.time()
p.close()
此外,当我使用“软”技术进行阈值处理时,我面临以下错误:
C:\Users\Sawon\anaconda3\lib\site-packages\pywt\_thresholding.py:25: RuntimeWarning: invalid value encountered in multiply
thresholded = data * thresholded
串行执行和并行执行的结果或多或少是相同的。 但在这里我得到的结果略有不同。 任何修改代码的建议都会有所帮助谢谢
[这不是问题的直接答案,而是比通过小评论框尝试以下更清晰的后续查询]
作为快速检查,将迭代器计数器传递给 spec_trans 并将其返回(以及您的结果) - 并将其推送到单独的列表中,transformXseq 或其他东西 - 然后与您的输入序列进行比较。 IE
def spec_trans(d,wav_fam,threshold_val,thresh_type, iCount):
data=np.array(d,dtype=np.float64)
data_dec=decomposition(data,wav_fam)
data_t=thresholding(data_dec,threshold_val,thresh_type)
data_rec=reconstruction(data_t,wav_fam)
return data_rec, iCount
然后在 main
jobs=[]
iJobs = 0
for dataBand in xmp:
jobs.append(p.apply_async(spec_trans,args=(dataBand,wav_fam,threshold_val,thresh_type, iJobs)))
iJobs = iJobs + 1
transformedX=[]
transformedXseq=[]
for jobBit in jobs:
res = jobBit.get()
transformedX.append(res[0])
transformedXseq.append(res[1])
...并检查列表 transformXseq 以查看您是否按照提交的顺序收集了作业。 它应该匹配!
假设wav_fam
、 threshold_val
和thresh_type
不会因调用而异,首先将这些 arguments 安排为第一个arguments 到工人spec_trans
:
def spec_trans(wav_fam, threshold_val, thresh_type, d):
现在我看不到您在池创建块中的哪个位置定义xmp
,但大概这是一个可迭代的。 您需要按如下方式修改此代码:
from functools import partial
def compute_chunksize(pool_size, iterable_size):
chunksize, remainder = divmod(iterable_size, 4 * pool_size)
if remainder:
chunksize += 1
return chunksize
if __name__ == '__main__':
X=tifffile.imread('data/Classification/university.tif')
#take paramaters
threshold_val=float(input("Enter the value for image thresholding: "))
print("The available wavelet functions:",pywt.wavelist())
wav_fam=input("Choose a wavelet function for transformation: ")
threshold_type=['hard','soft']
print("The available wavelet functions:",threshold_type)
thresh_type=input("Choose a type for threshholding technique: ")
start=time.time()
p = mp.Pool(4)
# first 3 arguments to spec_trans will be wav_fam, threshold_val and thresh_type
worker = partial(spec_trans, wav_fam, threshold_val, thresh_type)
suitable_chunksize = compute_chunksize(4, len(xmp))
transformedX = list(p.imap(worker, xmp, chunksize=suitable_chunksize))
end=time.time()
要获得比使用apply_async
更高的性能,您必须使用imap
的“合适的块大小”值。 Function compute_chunksize
可用于根据池的大小(即 4)和传递给imap
的可迭代的大小(即len(xmp)
来计算这样的值。 如果xmp
的大小足够小,以至于计算出的 chunksize 值为 1,我真的看不出imap
会比apply_async
性能显着提高。
当然,您也可以只使用:
transformedX = p.map(worker, xmp)
并让池计算自己合适的块大小。 当可迭代对象非常大并且还不是列表时, imap
比map
具有优势。 对于map
来计算合适的块大小,它首先必须将可迭代对象转换为列表以获取其长度,这可能是 memory 效率低下。 但是如果您知道可迭代对象的长度(或近似长度),那么通过使用 imap 您可以显式设置块大小,而无需将可迭代对象转换为列表。 与imap_unordered
相比, map
的另一个优势是,您可以在单个任务可用时处理它们的结果,而使用map
,您只有在所有提交的任务完成后才能获得结果。
更新
如果您想捕获提交给您的工作人员 function 的单个任务可能引发的异常,请坚持使用imap
,并使用以下代码迭代imap
返回的结果:
#transformedX = list(p.imap(worker, xmp, chunksize=suitable_chunksize))
transformedX = []
results = p.imap(worker, xmp, chunksize=suitable_chunksize)
import traceback
while True:
try:
result = next(results)
except StopIteration: # no more results
break
except Exception as e:
print('Exception occurred:', e)
traceback.print_exc() # print stacktrace
else:
transformedX.append(result)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.