繁体   English   中英

TensorFlow Keras: TypeError: can't pickle _thread.RLock objects when using multiprocessing

[英]TensorFlow Keras: TypeError: can't pickle _thread.RLock objects when using multiprocessing

我在 GitHub 中提出了这个问题: https://github.com/tensorflow/tensorflow/issues/46917

我正在尝试使用 multiprocessing(多进程)来加速我的一些代码。其中我必须向每个子进程发送一个 Keras model,用它对一些输入进行预测,然后再做一些后续计算。但是,我最终遇到了以下错误:

TypeError: can't pickle _thread.RLock objects

我试过了,

  1. 使用partial来修复 model 参数并使用生成的偏函数。
  2. 克隆 model 并为每个线程使用克隆
  3. 为每个线程保存和重新加载 model
  4. 尝试使用 pathos.multiprocessing

但这些方法都不起作用。

以下是MWE

import tensorflow as tf
from tensorflow import keras
import numpy as np


from multiprocessing import Pool
# from multiprocessing.dummy import Pool as ThreadPool
# from pathos.multiprocessing import ProcessingPool as Pool
from functools import partial


def simple_model():
    """Build and compile a small two-layer dense network for scalar inputs.

    Returns:
        A compiled ``keras`` Sequential model: ``Dense(10)`` followed by
        ``Dense(1, activation='sigmoid')``, using the ``'sgd'`` optimizer and
        ``'mean_squared_error'`` loss.
    """
    hidden = keras.layers.Dense(units=10, input_shape=[1])
    output = keras.layers.Dense(units=1, activation='sigmoid')
    net = keras.models.Sequential([hidden, output])
    net.compile(optimizer='sgd', loss='mean_squared_error')
    return net

def clone_model(model):
    """Clone ``model``, copy its weights into the clone, and compile the clone.

    Args:
        model: Keras model to duplicate.

    Returns:
        The clone produced by ``tf.keras.models.clone_model``, carrying the
        weights of ``model`` and compiled with ``'sgd'`` /
        ``'mean_squared_error'``.
    """
    source_weights = model.get_weights()
    duplicate = tf.keras.models.clone_model(model)
    duplicate.set_weights(source_weights)
    duplicate.build((None, 1))
    duplicate.compile(optimizer='sgd', loss='mean_squared_error')
    return duplicate

def work(model, seq):
    """Return the result of ``model.predict`` applied to ``seq``."""
    predictions = model.predict(seq)
    return predictions

def load_model(model_savepath):
    """Load and return the Keras model stored at ``model_savepath``."""
    restored = tf.keras.models.load_model(model_savepath)
    return restored

def worker(model, n = 4):
    """Split 0..99 into ``n`` rows and try to predict each row in a process pool.

    The model is saved to ``./simple_model.h5``, reloaded ``n`` times, and the
    reloaded models are handed to ``pool.map`` together with the input rows.

    NOTE: as written this does not succeed. The accompanying traceback shows
    ``pool.map`` raising ``TypeError: can't pickle _thread.RLock objects``
    while the task, which contains the Keras models, is being sent to the
    worker processes.
    """
    # reshape(n, -1) requires n to divide 100 evenly.
    seqences = np.arange(0,100).reshape(n, -1)
    pool = Pool()
    model_savepath = './simple_model.h5'
    model.save(model_savepath)
    # One freshly loaded model per row; these objects still live in the parent
    # process and must be pickled to reach the workers.
    model_list = [load_model(model_savepath) for _ in range(n)]
    # model_list = [clone_model(model) for _ in range(n)]
    # Each zip item is a single (model, seq) tuple, which pool.map passes to
    # work() as one positional argument. This is the call that raises the
    # pickling TypeError.
    results = pool.map(work, zip(model_list,seqences))
    # partial_work = partial(work, model=model)
    # results = pool.map(partial_work, seqences)
    pool.close()
    pool.join()
    
    # Flatten the per-row results into a single 1-D array.
    return np.reshape(results, (-1, ))



if __name__ == '__main__':
    # Build the demo model and print the flattened predictions from the pool.
    model = simple_model()
    print(worker(model, n=4))

这会导致以下错误跟踪:

File "c:/Users/***/Documents/GitHub/COVID-NSF/test4.py", line 42, in <module>
  out = worker(model, n=4)
File "c:/Users/****/Documents/GitHub/COVID-NSF/test4.py", line 30, in worker
  results = pool.map(work, zip(model_list,seqences))
File "C:\Users\****\anaconda3\envs\tf-gpu\lib\multiprocessing\pool.py", line 268, in map
  return self._map_async(func, iterable, mapstar, chunksize).get()
File "C:\Users\****\anaconda3\envs\tf-gpu\lib\multiprocessing\pool.py", line 657, in get
  raise self._value
File "C:\Users\***\anaconda3\envs\tf-gpu\lib\multiprocessing\pool.py", line 431, in _handle_tasks
  put(task)
File "C:\Users\***\anaconda3\envs\tf-gpu\lib\multiprocessing\connection.py", line 206, in send
  self._send_bytes(_ForkingPickler.dumps(obj))
File "C:\Users\***\anaconda3\envs\tf-gpu\lib\multiprocessing\reduction.py", line 51, in dumps
  cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects

@Aaron 感谢您解释 amahendrakar 对 GitHub 的评论。 我修改了代码,使代码将 model 的路径而不是 model 本身发送到子进程。 下面是工作代码

import tensorflow as tf
from tensorflow import keras
import numpy as np


# from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
from pathos.multiprocessing import ProcessingPool as Pool
from functools import partial
import time

def simple_model():
    """Build and compile a small two-layer dense network for scalar inputs.

    Returns:
        A compiled ``keras`` Sequential model: ``Dense(10)`` followed by
        ``Dense(1, activation='sigmoid')``, using the ``'sgd'`` optimizer and
        ``'mean_squared_error'`` loss.
    """
    hidden = keras.layers.Dense(units=10, input_shape=[1])
    output = keras.layers.Dense(units=1, activation='sigmoid')
    net = keras.models.Sequential([hidden, output])
    net.compile(optimizer='sgd', loss='mean_squared_error')
    return net

def clone_model(model):
    """Clone ``model``, copy its weights into the clone, and compile the clone.

    Args:
        model: Keras model to duplicate.

    Returns:
        The clone produced by ``tf.keras.models.clone_model``, carrying the
        weights of ``model`` and compiled with ``'sgd'`` /
        ``'mean_squared_error'``.
    """
    source_weights = model.get_weights()
    duplicate = tf.keras.models.clone_model(model)
    duplicate.set_weights(source_weights)
    duplicate.build((None, 1))
    duplicate.compile(optimizer='sgd', loss='mean_squared_error')
    return duplicate

def work(model, seq):
    """Return the result of ``model.predict`` applied to ``seq``."""
    predictions = model.predict(seq)
    return predictions

def work_new(seq, model_savepath='./simple_model.h5'):
    """Load the saved model inside the calling process and predict ``seq``.

    Only ``seq`` (and optionally a path string) is passed in, so no Keras
    model object has to be pickled to reach a pool worker.

    Args:
        seq: Input handed straight to ``model.predict``.
        model_savepath: Location of the saved model. Defaults to
            ``'./simple_model.h5'``, so existing single-argument calls behave
            as before.

    Returns:
        Whatever ``model.predict(seq)`` returns.
    """
    model = tf.keras.models.load_model(model_savepath)
    return model.predict(seq)

def load_model(model_savepath):
    """Load and return the Keras model stored at ``model_savepath``."""
    restored = tf.keras.models.load_model(model_savepath)
    return restored

def worker(model, n = 4):
    """Predict ``n`` chunks of input in a pool, sharing the model via disk.

    The model is written to ``./simple_model.h5`` and only the input chunks
    are sent to the pool; each task reloads the model inside ``work_new``
    rather than receiving a model object, which cannot be pickled.

    Args:
        model: Keras model to save; must provide ``save(path)``.
        n: Number of chunks. The input is ``arange(0, 10 * n)`` reshaped to
            ``n`` rows.

    Returns:
        A one-dimensional array holding all predictions, flattened.
    """
    sequences = np.arange(0, 10 * n).reshape(n, -1)
    pool = Pool()
    try:
        model_savepath = './simple_model.h5'
        model.save(model_savepath)
        results = pool.map(work_new, sequences)
    finally:
        # Always shut the pool down, even if saving or a task raises, so the
        # worker processes are not left behind.
        pool.close()
        pool.join()
    return np.reshape(results, (-1, ))



if __name__ == '__main__':
    # Time one end-to-end run of worker() on a fresh model with 40 chunks.
    model = simple_model()
    started = time.perf_counter()
    out = worker(model, n=40)
    elapsed = time.perf_counter() - started
    print(f"time taken {elapsed}")

这导致

time taken 8.521342800000001

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM