
Tensorflow Error using Pool: can't pickle _thread.RLock objects

I was trying to implement a massively parallel differential equation solver (30k DEs) on TensorFlow CPU but was running out of memory (around 30 GB of matrices). So I implemented a batch-based solver (solve for a small time span and save the data -> set the new initial state -> solve again), but the problem persisted. I learned that TensorFlow does not release memory until the Python interpreter is closed. So, based on info in GitHub issues, I tried implementing a multiprocessing solution using a pool, but I keep getting "can't pickle _thread.RLock objects" at the pooling step. Could someone please help?

def dAdt(X, t):
  dX = ...  # vector of derivatives
  return dX

global state_vector
global state

state_vector = [0]*n  # initial state

def tensor_process():
    global state  # without this, the assignment below creates a local variable
    with tf.Session() as sess:
        print("Session started...", end="")
        tf.global_variables_initializer().run()
        state = sess.run(tensor_state)
        # no explicit sess.close() needed; the with-block closes the session


n_batch = 3
t_batch = np.array_split(t, n_batch)


for n, i in enumerate(t_batch):
    print("Batch", (n+1), "Running...", end="")
    if n > 0:
        i = np.append(i[0]-0.01, i)
    print("Session started...", end="")
    init_state = tf.constant(state_vector, dtype=tf.float64)
    tensor_state = tf.contrib.integrate.odeint_fixed(dAdt, init_state, i)
    with Pool(1) as p:
        p.apply_async(tensor_process).get()
    state_vector = state[-1, :]
    np.save("state.batch"+str(n+1), state)
    state = None

TensorFlow doesn't support multiprocessing for several reasons, one being that a TensorFlow session itself cannot be forked or pickled. If you still want some kind of 'multi' approach, try multiprocessing.pool.ThreadPool, which worked for me:

https://stackoverflow.com/a/46049195/5276428

Note: I did this by creating multiple sessions over threads and then calling the variables belonging to each thread's session sequentially. If your issue is memory, I think it can be solved by reducing the input batch size.

Rather than using a Pool of N workers, try creating N distinct multiprocessing.Process instances, passing your tensor_process() function as the target argument and each subset of data via the args argument. Start the processes inside a for-loop, then join them beneath the loop. You can use a shared multiprocessing.Queue object to return results to the main process.

I have personally had success combining TensorFlow with Python's multiprocessing module by sub-classing Process and overriding its run() method.

def run(self):
  logging.info('started inference.')
  logging.debug('TF input frame shape == {}'.format(self.tensor_shape))

  count = 0

  with tf.device('/cpu:0') if self.device_type == 'cpu' else \
      tf.device(None):
    with tf.Session(config=self.session_config) as session:
      frame_dataset = tf.data.Dataset.from_generator(
        self.generate_frames, tf.uint8, tf.TensorShape(self.tensor_shape))
      frame_dataset = frame_dataset.map(self._preprocess_frames,
                                        self._get_num_parallel_calls())
      frame_dataset = frame_dataset.batch(self.batch_size)
      frame_dataset = frame_dataset.prefetch(self.batch_size)
      next_batch = frame_dataset.make_one_shot_iterator().get_next()

      while True:
        try:
          frame_batch = session.run(next_batch)
          probs = session.run(self.output_node,
                              {self.input_node: frame_batch})
          self.prob_array[count:count + probs.shape[0]] = probs
          count += probs.shape[0]
        except tf.errors.OutOfRangeError:
          logging.info('completed inference.')
          break

  self.result_queue.put((count, self.prob_array, self.timestamp_array))
  self.result_queue.close()

I would write an example based on your code, but I don't quite understand it.
