[英]Multiprocessing with xarray and a NumPy array
So I am trying to implement a solution that was already described here, but I am changing it a bit. 因此,我正在尝试实施已在此处描述的解决方案,但我正在对其进行一些更改。 Instead of just trying to change the array with operations, I am trying to read from a NetCDF file using xarray and then write to a shared numpy array with the multiprocessing module.
我不只是尝试通过操作更改数组,而是尝试使用 xarray 从 NetCDF 文件中读取,然后使用多处理模块写入共享的 numpy 数组。
I feel as though I am getting pretty close, but something is going wrong.我觉得我已经很接近了,但出了点问题。 I have pasted a reproducible, easy copy/paste example below.
我在下面粘贴了一个可重复的、简单的复制/粘贴示例。 As you can see, when I run the processes, they can all read the files that I created, but they do not correctly update the shared numpy array that I am trying to write to.
如您所见,当我运行这些进程时,它们都可以读取我创建的文件,但它们没有正确更新我尝试写入的共享 numpy 数组。 Any help would be appreciated.
任何帮助,将不胜感激。
Code代码
import ctypes
import logging
import multiprocessing as mp
import xarray as xr
from contextlib import closing
import numpy as np
# Shorthand for the multiprocessing logger's info() method (not referenced
# anywhere else in this listing).
info = mp.get_logger().info
def main():
    """Write four NetCDF test files and try to load them into a shared array.

    Each file holds one variable 'x' containing 0..9.  A float32 shared
    buffer of N * M elements is created, viewed as an (N, M) array and
    zeroed; a process pool then runs ``g`` once per (file, row) pair.
    The initial copy and the final contents of the buffer are printed.
    """
    values = np.arange(10)
    for index in range(4):
        dataset = xr.Dataset({'x': values})
        dataset.to_netcdf('test_{}.nc'.format(index))
        dataset.close()

    log = mp.log_to_stderr()
    log.setLevel(logging.INFO)

    # Shared buffer, viewed from the parent as an (N, M) float32 array.
    N, M = 4, 10
    shared_arr = mp.Array(ctypes.c_float, N * M)
    view = tonumpyarray(shared_arr, dtype=np.float32).reshape((N, M))

    # Start from all zeros and keep a snapshot for comparison.
    view[:, :] = np.zeros((N, M))
    snapshot = view.copy()

    files = ['test_{}.nc'.format(index) for index in range(4)]
    parameter_tuples = [(name, row) for row, name in enumerate(files)]

    # Each task is meant to write a different row of the same array.
    # The result of map_async is not kept, so worker errors are not reported.
    with closing(mp.Pool(initializer=init, initargs=(shared_arr,))) as pool:
        pool.map_async(g, parameter_tuples)
    pool.join()

    print(snapshot)
    print(tonumpyarray(shared_arr, np.float32).reshape(N, M))
def init(shared_arr_):
    """Pool initializer: publish the shared array as a module-level global.

    Args:
        shared_arr_: the shared array handed to every worker at start-up.
    """
    # Workers must inherit the array at start-up rather than receive it
    # as a task argument.
    globals()['shared_arr'] = shared_arr_
def tonumpyarray(mp_arr, dtype=np.float64):
    """Return a 1-D NumPy view over the buffer behind ``mp_arr``.

    Args:
        mp_arr: object whose ``get_obj()`` returns the underlying buffer.
        dtype: element type used to interpret the buffer.

    Returns:
        An array sharing memory with the buffer (no copy is made).
    """
    raw_buffer = mp_arr.get_obj()
    return np.frombuffer(raw_buffer, dtype=dtype)
def g(params):
    """Copy variable 'x' of one NetCDF file into a row of the shared array.

    No synchronization is performed.

    Args:
        params: sequence whose first item is the file name and whose
            second item is the target row index.
    """
    file_name = params[0]
    row = params[1]
    print("Current File Name: ", file_name)
    dataset = xr.open_dataset(file_name)
    print(dataset["x"].data[:])
    # NOTE(review): the view below uses tonumpyarray's default dtype and is
    # never reshaped to 2-D; this appears to be the failure the post asks
    # about -- confirm against the shared buffer's element type and shape.
    view = tonumpyarray(shared_arr)
    view[row, :] = dataset["x"].data[:]
    dataset.close()
# Script entry point: run the demo only when executed directly.
if __name__ == '__main__':
    # Call freeze_support() first, before any worker processes are created.
    mp.freeze_support()
    main()
1. You forgot to reshape back after tonumpyarray.
1. 你忘记在 tonumpyarray 之后重新 reshape。
2. You used the wrong dtype in tonumpyarray.
2. 你在 tonumpyarray 中使用了错误的 dtype。
import ctypes
import logging
import multiprocessing as mp
import xarray as xr
from contextlib import closing
import numpy as np
# Shorthand for the multiprocessing logger's info() method (not referenced
# anywhere else in this listing).
info = mp.get_logger().info
def main():
    """Write four NetCDF test files and load each into a row of a shared array.

    Each file holds one variable 'x' containing 0..M-1.  A float32 shared
    buffer of N * M elements is created and viewed as an (N, M) array; a
    process pool then runs ``g`` once per (file, row) pair.  The initial
    copy and the final contents of the buffer are printed.

    Raises:
        Exception: whatever a worker raised while running ``g``; it is
            re-raised here instead of being silently dropped.
    """
    N, M = 4, 10
    files = ['test_{}.nc'.format(i) for i in range(N)]

    # Create the input files.
    data = np.arange(M)
    for path in files:
        ds = xr.Dataset({'x': data})
        ds.to_netcdf(path)
        ds.close()

    logger = mp.log_to_stderr()
    logger.setLevel(logging.INFO)

    # Create the shared array and view it as (N, M) float32.
    shared_arr = mp.Array(ctypes.c_float, N * M)
    arr = tonumpyarray(shared_arr, dtype=np.float32).reshape((N, M))

    # Fill with zeros and keep a snapshot for comparison.
    arr.fill(0)
    arr_orig = arr.copy()

    # One (file name, row index) task per input file.
    parameter_tuples = [(path, row) for row, path in enumerate(files)]

    # Write to the array from different processes; each task targets its
    # own row of the same buffer.
    with closing(mp.Pool(initializer=init, initargs=(shared_arr, N, M))) as p:
        async_result = p.map_async(g, parameter_tuples)
    p.join()
    # map_async() itself never raises for worker errors; get() re-raises the
    # exception from a failed task so problems in g() are not hidden.
    async_result.get()

    print(arr_orig)
    print(tonumpyarray(shared_arr, np.float32).reshape(N, M))
def init(shared_arr_, N_, M_):
    """Pool initializer: publish the shared array and its shape as globals.

    Args:
        shared_arr_: the shared array handed to every worker at start-up.
        N_: number of rows of the 2-D view.
        M_: number of columns of the 2-D view.
    """
    global shared_arr, N, M
    # Workers must inherit the array at start-up rather than receive it
    # as a task argument.
    shared_arr, N, M = shared_arr_, N_, M_
def tonumpyarray(mp_arr, dtype=np.float32):
    """Return a 1-D NumPy view over the buffer behind ``mp_arr``.

    Args:
        mp_arr: object whose ``get_obj()`` returns the underlying buffer.
        dtype: element type used to interpret the buffer; the default is
            float32.

    Returns:
        An array sharing memory with the buffer (no copy is made).
    """
    raw_buffer = mp_arr.get_obj()
    return np.frombuffer(raw_buffer, dtype=dtype)
def g(params):
    """Copy variable 'x' of one NetCDF file into a row of the shared array.

    No synchronization is performed.  Relies on the globals ``shared_arr``,
    ``N`` and ``M`` set by the pool initializer.

    Args:
        params: sequence whose first item is the file name and whose
            second item is the target row index.
    """
    file_name = params[0]
    row = params[1]
    print("Current File Name: ", file_name)
    # The context manager closes the file even if the copy below raises.
    with xr.open_dataset(file_name) as tmp_dataset:
        values = tmp_dataset["x"].data[:]
        print(values)
        # View the shared buffer as float32 and restore its 2-D shape before
        # indexing by row; the dtype is passed explicitly so this does not
        # depend on the helper's default.
        arr = tonumpyarray(shared_arr, np.float32).reshape(N, M)
        arr[row, :] = values
# Script entry point: run the demo only when executed directly.
if __name__ == '__main__':
    # Call freeze_support() first, before any worker processes are created.
    mp.freeze_support()
    main()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.