繁体   English   中英

将 Tensorflow 与多处理一起使用时,无法腌制“weakref”object

[英]Cannot pickle 'weakref' object when using Tensorflow with Multiprocessing

我想同时训练几个神经网络,我正在尝试使用multiprocessing模块,以便可以在单独的过程中训练每个网络,但我遇到了一个问题。 当我运行下面的演示代码时(由于apply_async function 没有给出错误提示,我暂时将其更改为apply功能):

import tensorflow as tf
import multiprocessing as mp


class SeqModel(tf.keras.Sequential):
    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.add(tf.keras.layers.Dense(hidden_sizes[0], activation="relu", input_shape=(input_size,)))
        for hidden_size in hidden_sizes[1:]: self.add(tf.keras.layers.Dense(hidden_size, activation="relu"))
        if output_size is not None: self.add(tf.keras.layers.Dense(output_size))


class Partition:
    def __init__(self, partition_id):
        self.partition_id = partition_id
        self.model = None

    def initialization(self):
        self.model = SeqModel(10,[10,10],10)

    def test(self):
        print(f'partition {self.partition_id} testing...')


def func():
    partition_list = [Partition(i) for i in range(4)]

    for partition in partition_list: partition.initialization()

    p = mp.Pool(4)
    for partition in partition_list:
        p.apply(partition.test)
    p.close()
    p.join()


if __name__ == '__main__':
    func()

我收到以下错误:

Traceback (most recent call last):
  File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 43, in <module>
    func()
  File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 37, in func
    p.apply(partition.test)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 357, in apply
    return self.apply_async(func, args, kwds).get()
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 771, in get
    raise self._value
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 537, in _handle_tasks
    put(task)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: cannot pickle 'weakref' object

如果我不进行分区初始化(分区实例中不涉及 SeqModel),则代码运行没有问题。 这是否意味着我不能在子进程中使用 tf 模型?

要使用Pool ,您的对象必须是可挑选的,因为Pool方法使用mp.SimpleQueue将任务发送到进程,而mp.SimpleQueue只接受挑选的对象。

Tensorflow 模型默认情况下不可选择,因此您不能轻松地将池与 Tensorflow 模型一起使用。 请参阅Model TensorFlow选择。

但是,您可以尝试通过讨论https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883中建议的解决方法,使Model

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM