[英]Tensorflow Error using Pool: can't pickle _thread.RLock objects
I was trying to implement a massively parallel Differential Equation solver (30k DEs) on Tensorflow CPU but was running out of memory (Around 30GB matrices). 我试图在Tensorflow CPU上实现大规模并行微分方程求解器(30k DE),但内存不足(大约30GB矩阵)。 So I implemented a batch based solver (solve for small time and save data -> set new initial -> solve again).
因此,我实现了一个基于批处理的求解器(求解时间短,并保存数据->设置新的初始值->再次求解)。 But the problem persisted.
但是问题仍然存在。 I learnt that Tensorflow does not clear the memory until the python interpreter is closed.
我了解到,在关闭python解释器之前,Tensorflow不会清除内存。 So based on info on github issues I tried implementing a multiprocessing solution using pool but I keep getting a "can't pickle _thread.RLock objects" at the Pooling step.
因此,基于有关github问题的信息,我尝试使用池来实现多处理解决方案,但在池化步骤中,我始终收到“无法腌制_thread.RLock对象”的信息。 Could someone please help!
有人可以帮忙吗!
def dAdt(X,t):
dX = // vector of differential
return dX
global state_vector
global state
state_vector = [0]*n // initial state
def tensor_process():
with tf.Session() as sess:
print("Session started...",end="")
tf.global_variables_initializer().run()
state = sess.run(tensor_state)
sess.close()
n_batch = 3
t_batch = np.array_split(t,n_batch)
for n,i in enumerate(t_batch):
print("Batch",(n+1),"Running...",end="")
if n>0:
i = np.append(i[0]-0.01,i)
print("Session started...",end="")
init_state = tf.constant(state_vector, dtype=tf.float64)
tensor_state = tf.contrib.odeint_fixed(dAdt, init_state, i)
with Pool(1) as p:
p.apply_async(tensor_process).get()
state_vector = state[-1,:]
np.save("state.batch"+str(n+1),state)
state=None
Tensorflow doesn't support multiprocessing due to many reasons like it not able to fork the TensorFlow session itself. Tensorflow不支持多重处理,原因有很多,例如它无法派生TensorFlow会话本身。 If you still want to use some kind of 'multi' stuff, try this (multiprocessing.pool.ThreadPool) which worked for me:
如果您仍然想使用某种“多”东西,请尝试对我有用的这个(multiprocessing.pool.ThreadPool):
https://stackoverflow.com/a/46049195/5276428 https://stackoverflow.com/a/46049195/5276428
Note: I did this by creating multiple sessions over threads and then calling each session variables belonging to each thread sequentially. 注意:我是通过在线程上创建多个会话,然后依次调用属于每个线程的每个会话变量来实现的。 If your issue is memory, I think it can be solved by reducing input batch-size.
如果您的问题是内存,我认为可以通过减少输入批处理大小来解决。
Rather than use a Pool of N workers, try creating N distinct instances of multiprocessing.Process objects and passing your tensor_process() function as the target argument and each subset of data as the args arguments. 不要使用N个工作池,而是尝试创建N个不同的multiprocessing.Process对象实例,并将tensor_process()函数作为目标参数,并将每个数据子集作为args参数。 Start the processes inside a for-loop, then join them beneath the loop.
在for循环内启动进程,然后将其加入循环下。 You can use a shared multiprocessing.Queue object to return results to the main process.
您可以使用共享的multiprocessing.Queue对象将结果返回到主流程。
I have personally had success combining TensorFlow with Python's multiprocessing module by sub-classing Process and overriding its run() method . 我个人通过将Process子类化并覆盖其run()方法,成功地将TensorFlow与Python的多处理模块结合在一起。
def run(self):
logging.info('started inference.')
logging.debug('TF input frame shape == {}'.format(self.tensor_shape))
count = 0
with tf.device('/cpu:0') if self.device_type == 'cpu' else \
tf.device(None):
with tf.Session(config=self.session_config) as session:
frame_dataset = tf.data.Dataset.from_generator(
self.generate_frames, tf.uint8, tf.TensorShape(self.tensor_shape))
frame_dataset = frame_dataset.map(self._preprocess_frames,
self._get_num_parallel_calls())
frame_dataset = frame_dataset.batch(self.batch_size)
frame_dataset = frame_dataset.prefetch(self.batch_size)
next_batch = frame_dataset.make_one_shot_iterator().get_next()
while True:
try:
frame_batch = session.run(next_batch)
probs = session.run(self.output_node,
{self.input_node: frame_batch})
self.prob_array[count:count + probs.shape[0]] = probs
count += probs.shape[0]
except tf.errors.OutOfRangeError:
logging.info('completed inference.')
break
self.result_queue.put((count, self.prob_array, self.timestamp_array))
self.result_queue.close()
I would write an example based on your code, but I don't quite understand it. 我会根据您的代码编写一个示例,但我不太理解。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.