[英]Cannot pickle 'weakref' object when using Tensorflow with Multiprocessing
我想同时训练几个神经网络,我正在尝试使用multiprocessing
模块,以便可以在单独的过程中训练每个网络,但我遇到了一个问题。 当我运行下面的演示代码时(由于apply_async
function 没有给出错误提示,我暂时将其更改为apply
功能):
import tensorflow as tf
import multiprocessing as mp
class SeqModel(tf.keras.Sequential):
def __init__(self, input_size, hidden_sizes, output_size):
super().__init__()
self.add(tf.keras.layers.Dense(hidden_sizes[0], activation="relu", input_shape=(input_size,)))
for hidden_size in hidden_sizes[1:]: self.add(tf.keras.layers.Dense(hidden_size, activation="relu"))
if output_size is not None: self.add(tf.keras.layers.Dense(output_size))
class Partition:
def __init__(self, partition_id):
self.partition_id = partition_id
self.model = None
def initialization(self):
self.model = SeqModel(10,[10,10],10)
def test(self):
print(f'partition {self.partition_id} testing...')
def func():
partition_list = [Partition(i) for i in range(4)]
for partition in partition_list: partition.initialization()
p = mp.Pool(4)
for partition in partition_list:
p.apply(partition.test)
p.close()
p.join()
if __name__ == '__main__':
func()
我收到以下错误:
Traceback (most recent call last):
File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 43, in <module>
func()
File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 37, in func
p.apply(partition.test)
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 357, in apply
return self.apply_async(func, args, kwds).get()
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 771, in get
raise self._value
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 537, in _handle_tasks
put(task)
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
TypeError: cannot pickle 'weakref' object
如果我不进行分区初始化(分区实例中不涉及 SeqModel),则代码运行没有问题。 这是否意味着我不能在子进程中使用 tf 模型?
要使用Pool
,您的对象必须是可挑选的,因为Pool
方法使用mp.SimpleQueue
将任务发送到进程,而mp.SimpleQueue
只接受挑选的对象。
Tensorflow 模型默认情况下不可选择,因此您不能轻松地将池与 Tensorflow 模型一起使用。 请参阅Model
TensorFlow
选择。
但是,您可以尝试通过讨论https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883中建议的解决方法,使Model
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.