简体   繁体   中英

TensorFlow Keras: TypeError: can't pickle _thread.RLock objects when using multiprocessing

I have raised this issue in GitHub: https://github.com/tensorflow/tensorflow/issues/46917

I am trying to use multiprocessing to speed up some of my code. To do so, I have to send a Keras model to each worker and use it to predict on some inputs and do some follow-up computations. However, I end up with the following error:

TensorFlow Keras: TypeError: can't pickle _thread.RLock objects

I tried:

  1. using partial to fix the model argument and using the resulting partial function,
  2. cloning the model and using a clone for each thread,
  3. saving and reloading a model for each thread,
  4. using pathos.multiprocessing,

but none of them worked.

The following is a minimal working example (MWE):

import tensorflow as tf
from tensorflow import keras
import numpy as np


from multiprocessing import Pool
# from multiprocessing.dummy import Pool as ThreadPool
# from pathos.multiprocessing import ProcessingPool as Pool
from functools import partial


def simple_model():
    """Build and compile a tiny two-layer dense network.

    The network feeds a single input feature through a 10-unit dense layer
    into a 1-unit sigmoid output layer, and is compiled with the SGD
    optimizer and a mean-squared-error loss.

    Returns:
        The compiled ``keras.models.Sequential`` model.
    """
    hidden = keras.layers.Dense(units=10, input_shape=[1])
    output = keras.layers.Dense(units=1, activation='sigmoid')
    net = keras.models.Sequential([hidden, output])
    net.compile(optimizer='sgd', loss='mean_squared_error')
    return net

def clone_model(model):
    """Clone ``model``, copy its weights into the clone, and compile it.

    Args:
        model: The Keras model to duplicate.

    Returns:
        The clone, built for input shape ``(None, 1)`` and compiled with
        the SGD optimizer and a mean-squared-error loss.
    """
    replica = tf.keras.models.clone_model(model)
    source_weights = model.get_weights()
    replica.set_weights(source_weights)
    replica.build((None, 1))
    replica.compile(optimizer='sgd', loss='mean_squared_error')
    return replica

def work(model, seq):
    """Return the predictions of ``model`` for the inputs in ``seq``.

    Args:
        model: Any object exposing a ``predict`` method.
        seq: The inputs handed to ``model.predict``.

    Returns:
        Whatever ``model.predict(seq)`` returns.
    """
    predictions = model.predict(seq)
    return predictions

def load_model(model_savepath):
    """Load and return the Keras model stored at ``model_savepath``."""
    restored = tf.keras.models.load_model(model_savepath)
    return restored

def worker(model, n = 4):
    """Try to predict on ``n`` chunks of input in parallel (failing repro).

    The integers 0..99 are reshaped into ``n`` rows, the model is saved to
    ``./simple_model.h5`` and reloaded ``n`` times in the parent process,
    and the (model, sequence) pairs are handed to ``Pool.map``.

    This is the reproduction of the reported problem: the pool has to
    pickle each task before sending it to a child process, and the Keras
    model objects inside the tasks cannot be pickled, so ``pool.map``
    raises ``TypeError: can't pickle _thread.RLock objects``.

    Args:
        model: The Keras model to save and reload.
        n: Number of rows the 100 inputs are split into.

    Returns:
        The pool results flattened to one dimension (never reached,
        because ``pool.map`` raises first).
    """
    seqences = np.arange(0,100).reshape(n, -1)
    pool = Pool()
    model_savepath = './simple_model.h5'
    model.save(model_savepath)
    # Attempt 3 from the question: one freshly reloaded model per chunk.
    # The models are still created in the parent, so they still have to be
    # pickled to reach the children.
    model_list = [load_model(model_savepath) for _ in range(n)]
    # Attempt 2 from the question: one clone per chunk.
    # model_list = [clone_model(model) for _ in range(n)]
    # NOTE(review): Pool.map passes each zipped (model, seq) tuple as a
    # single argument, while work() takes two parameters; starmap would be
    # needed even if pickling succeeded.
    results = pool.map(work, zip(model_list,seqences))
    # Attempt 1 from the question: bind the model with functools.partial.
    # partial_work = partial(work, model=model)
    # results = pool.map(partial_work, seqences)
    pool.close()
    pool.join()
    
    return np.reshape(results, (-1, ))



if __name__ == '__main__':

    # Script entry point for the reproduction: build the model, run the
    # pooled prediction on 4 chunks, and print the flattened result.
    model = simple_model()
    out = worker(model, n=4)
    print(out)

This results in the following error trace:

File "c:/Users/***/Documents/GitHub/COVID-NSF/test4.py", line 42, in <module>
  out = worker(model, n=4)
File "c:/Users/****/Documents/GitHub/COVID-NSF/test4.py", line 30, in worker
  results = pool.map(work, zip(model_list,seqences))
File "C:\Users\****\anaconda3\envs\tf-gpu\lib\multiprocessing\pool.py", line 268, in map
  return self._map_async(func, iterable, mapstar, chunksize).get()
File "C:\Users\****\anaconda3\envs\tf-gpu\lib\multiprocessing\pool.py", line 657, in get
  raise self._value
File "C:\Users\***\anaconda3\envs\tf-gpu\lib\multiprocessing\pool.py", line 431, in _handle_tasks
  put(task)
File "C:\Users\***\anaconda3\envs\tf-gpu\lib\multiprocessing\connection.py", line 206, in send
  self._send_bytes(_ForkingPickler.dumps(obj))
File "C:\Users\***\anaconda3\envs\tf-gpu\lib\multiprocessing\reduction.py", line 51, in dumps
  cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects

@Aaron, thanks for explaining the comment made by amahendrakar on GitHub. I modified the code so that it sends the path of the model, rather than the model itself, to the child processes. Below is the working code:

import tensorflow as tf
from tensorflow import keras
import numpy as np


# from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
from pathos.multiprocessing import ProcessingPool as Pool
from functools import partial
import time

def simple_model():
    """Create the small dense model used by the timing example.

    Two dense layers are stacked: 10 units on a single input feature,
    followed by 1 sigmoid unit. The model is compiled with the SGD
    optimizer and a mean-squared-error loss before being returned.

    Returns:
        The compiled ``keras.models.Sequential`` model.
    """
    layer_stack = [
        keras.layers.Dense(units=10, input_shape=[1]),
        keras.layers.Dense(units=1, activation='sigmoid'),
    ]
    network = keras.models.Sequential(layer_stack)
    network.compile(optimizer='sgd', loss='mean_squared_error')
    return network

def clone_model(model):
    """Return a compiled clone of ``model`` carrying the same weights.

    Args:
        model: The Keras model to duplicate.

    Returns:
        The clone, with the weights of ``model`` copied in, built for
        input shape ``(None, 1)`` and compiled with the SGD optimizer and
        a mean-squared-error loss.
    """
    duplicate = tf.keras.models.clone_model(model)
    weights = model.get_weights()
    duplicate.set_weights(weights)
    duplicate.build((None, 1))
    duplicate.compile(optimizer='sgd', loss='mean_squared_error')
    return duplicate

def work(model, seq):
    """Predict on ``seq`` with an already-loaded ``model``.

    Args:
        model: Any object exposing a ``predict`` method.
        seq: The inputs handed to ``model.predict``.

    Returns:
        Whatever ``model.predict(seq)`` returns.
    """
    outputs = model.predict(seq)
    return outputs

def work_new(seq, model_savepath='./simple_model.h5'):
    """Load the saved model inside the worker and predict on ``seq``.

    The model is loaded from disk in the process that runs this function,
    so only ``seq`` (and, optionally, a path string) has to be sent to the
    worker; no Keras model object crosses the process boundary.

    Args:
        seq: The inputs handed to ``model.predict``.
        model_savepath: Path of the saved model file. Defaults to
            ``'./simple_model.h5'``, so ``work_new(seq)`` still works;
            pass another path (e.g. via ``functools.partial``) to use a
            different model file.

    Returns:
        Whatever ``model.predict(seq)`` returns.
    """
    model = tf.keras.models.load_model(model_savepath)
    return model.predict(seq)

def load_model(model_savepath):
    """Read the Keras model saved at ``model_savepath`` and return it."""
    loaded = tf.keras.models.load_model(model_savepath)
    return loaded

def worker(model, n = 4):
    """Predict on ``n`` chunks of input in parallel worker processes.

    The integers ``0 .. 10*n - 1`` are reshaped into ``n`` rows. The model
    is saved to ``./simple_model.h5`` and each row is mapped over
    ``work_new``, which reloads the model from that file inside the
    worker. Only the input rows are sent to the workers, never the model.

    Args:
        model: The Keras model to save for the workers.
        n: Number of input rows (each row holds 10 values).

    Returns:
        The per-row predictions flattened into a one-dimensional array.
    """
    sequences = np.arange(0, 10 * n).reshape(n, -1)
    model_savepath = './simple_model.h5'
    pool = Pool()
    try:
        # The file must exist before any task runs, because work_new
        # loads the model from this path.
        model.save(model_savepath)
        results = pool.map(work_new, sequences)
    finally:
        # Always shut the pool down, even when saving or mapping fails.
        pool.close()
        pool.join()
        # pathos caches pools between Pool() calls; clear the cached
        # (now closed) pool so a later call to worker() gets a fresh one.
        pool.clear()
    return np.reshape(results, (-1, ))



if __name__ == '__main__':
    # Build the model first so its construction is not part of the timing.
    model = simple_model()

    # Time only the parallel prediction over 40 input rows.
    t0 = time.perf_counter()
    out = worker(model, n=40)
    t1 = time.perf_counter()

    elapsed = t1 - t0
    print(f"time taken {elapsed}")

This results in

time taken 8.521342800000001

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM