[英]how to pass parameters other than data through pool.imap() function for multiprocessing in python?
我正在研究高光譜圖像。 為了減少圖像中的噪聲,我使用 pywt package 進行小波變換。 當我正常執行此操作(串行處理)時,它工作順利。 但是當我嘗試使用多核對圖像進行小波變換來實現並行處理時,我必須傳遞某些參數,例如
但是我無法使用池 object 傳遞這些參數,當我使用 pool.imap() 時,我只能將數據作為參數傳遞。 但是當我使用 pool.apply_async() 時,它需要更多的時間,而且 output 的順序也不一樣。 我在這里添加代碼以供參考:
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing as mp
import os
import time
from math import log10, sqrt
import pywt
import tifffile
def spec_trans(d,wav_fam,threshold_val,thresh_type):
data=np.array(d,dtype=np.float64)
data_dec=decomposition(data,wav_fam)
data_t=thresholding(data_dec,threshold_val,thresh_type)
data_rec=reconstruction(data_t,wav_fam)
return data_rec
if __name__ == '__main__':
#input
X=tifffile.imread('data/Classification/university.tif')
#take paramaters
threshold_val=float(input("Enter the value for image thresholding: "))
print("The available wavelet functions:",pywt.wavelist())
wav_fam=input("Choose a wavelet function for transformation: ")
threshold_type=['hard','soft']
print("The available wavelet functions:",threshold_type)
thresh_type=input("Choose a type for threshholding technique: ")
start=time.time()
p = mp.Pool(4)
jobs=[]
for dataBand in xmp:
jobs.append(p.apply_async(spec_trans,args=(dataBand,wav_fam,threshold_val,thresh_type)))
transformedX=[]
for jobBit in jobs:
transformedX.append(jobBit.get())
end=time.time()
p.close()
此外,當我使用“軟”技術進行閾值處理時,我面臨以下錯誤:
C:\Users\Sawon\anaconda3\lib\site-packages\pywt\_thresholding.py:25: RuntimeWarning: invalid value encountered in multiply
thresholded = data * thresholded
串行執行和並行執行的結果或多或少是相同的。 但在這里我得到的結果略有不同。 任何修改代碼的建議都會有所幫助謝謝
[這不是問題的直接答案,而是比通過小評論框嘗試以下更清晰的后續查詢]
作為快速檢查,將迭代器計數器傳遞給 spec_trans 並將其返回(以及您的結果) - 並將其推送到單獨的列表中,transformXseq 或其他東西 - 然后與您的輸入序列進行比較。 IE
def spec_trans(d,wav_fam,threshold_val,thresh_type, iCount):
data=np.array(d,dtype=np.float64)
data_dec=decomposition(data,wav_fam)
data_t=thresholding(data_dec,threshold_val,thresh_type)
data_rec=reconstruction(data_t,wav_fam)
return data_rec, iCount
然后在 main
jobs=[]
iJobs = 0
for dataBand in xmp:
jobs.append(p.apply_async(spec_trans,args=(dataBand,wav_fam,threshold_val,thresh_type, iJobs)))
iJobs = iJobs + 1
transformedX=[]
transformedXseq=[]
for jobBit in jobs:
res = jobBit.get()
transformedX.append(res[0])
transformedXseq.append(res[1])
...並檢查列表 transformXseq 以查看您是否按照提交的順序收集了作業。 它應該匹配!
假設wav_fam
、 threshold_val
和thresh_type
不會因調用而異,首先將這些 arguments 安排為第一個arguments 到工人spec_trans
:
def spec_trans(wav_fam, threshold_val, thresh_type, d):
現在我看不到您在池創建塊中的哪個位置定義xmp
,但大概這是一個可迭代的。 您需要按如下方式修改此代碼:
from functools import partial
def compute_chunksize(pool_size, iterable_size):
chunksize, remainder = divmod(iterable_size, 4 * pool_size)
if remainder:
chunksize += 1
return chunksize
if __name__ == '__main__':
X=tifffile.imread('data/Classification/university.tif')
#take paramaters
threshold_val=float(input("Enter the value for image thresholding: "))
print("The available wavelet functions:",pywt.wavelist())
wav_fam=input("Choose a wavelet function for transformation: ")
threshold_type=['hard','soft']
print("The available wavelet functions:",threshold_type)
thresh_type=input("Choose a type for threshholding technique: ")
start=time.time()
p = mp.Pool(4)
# first 3 arguments to spec_trans will be wav_fam, threshold_val and thresh_type
worker = partial(spec_trans, wav_fam, threshold_val, thresh_type)
suitable_chunksize = compute_chunksize(4, len(xmp))
transformedX = list(p.imap(worker, xmp, chunksize=suitable_chunksize))
end=time.time()
要獲得比使用apply_async
更高的性能,您必須使用imap
的“合適的塊大小”值。 Function compute_chunksize
可用於根據池的大小(即 4)和傳遞給imap
的可迭代的大小(即len(xmp)
來計算這樣的值。 如果xmp
的大小足夠小,以至於計算出的 chunksize 值為 1,我真的看不出imap
會比apply_async
性能顯着提高。
當然,您也可以只使用:
transformedX = p.map(worker, xmp)
並讓池計算自己合適的塊大小。 當可迭代對象非常大並且還不是列表時, imap
比map
具有優勢。 對於map
來計算合適的塊大小,它首先必須將可迭代對象轉換為列表以獲取其長度,這可能是 memory 效率低下。 但是如果您知道可迭代對象的長度(或近似長度),那么通過使用 imap 您可以顯式設置塊大小,而無需將可迭代對象轉換為列表。 與imap_unordered
相比, map
的另一個優勢是,您可以在單個任務可用時處理它們的結果,而使用map
,您只有在所有提交的任務完成后才能獲得結果。
更新
如果您想捕獲提交給您的工作人員 function 的單個任務可能引發的異常,請堅持使用imap
,並使用以下代碼迭代imap
返回的結果:
#transformedX = list(p.imap(worker, xmp, chunksize=suitable_chunksize))
transformedX = []
results = p.imap(worker, xmp, chunksize=suitable_chunksize)
import traceback
while True:
try:
result = next(results)
except StopIteration: # no more results
break
except Exception as e:
print('Exception occurred:', e)
traceback.print_exc() # print stacktrace
else:
transformedX.append(result)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.