
Is tf.gradients thread-safe?

I have multiple calls to tf.gradients, each of which takes some time, so I would like to make the tf.gradients calls concurrently. However, when I try to do so on the graph I get one of several errors. I suspect it is not thread-safe, but I have not been able to reproduce the error with an MWE. I tried both pathos.pools.ThreadPool and pathos.pools.ProcessPool in my MWE and in my real code, and only my real code fails. Here is the MWE I tried:

from pathos.pools import ThreadPool, ProcessPool
import tensorflow as tf
import numpy as np

# Three random constant tensors and three expressions built from them.
Xs = [tf.cast(np.random.random((10,10)), dtype=tf.float64) for i in range(3)]
Ys = [Xs[0]*Xs[1]*Xs[2], Xs[0]/Xs[1]*Xs[2], Xs[0]/Xs[1]/Xs[2]]

def compute_grad(YX):
    # Each worker builds the gradient ops for one (Y, X) pair.
    return tf.gradients(YX[0], YX[1])

# Map the gradient construction over three worker threads.
tp = ThreadPool(3)
res = tp.map(compute_grad, zip(Ys, Xs))
print(res)

Here is part of the traceback I get when I try my actual code. This is the ThreadPool version.

File "pathos/threading.py", line 134, in map
    return _pool.map(star(f), zip(*args)) # chunksize
  File "multiprocess/pool.py", line 260, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "multiprocess/pool.py", line 608, in get
    raise self._value
  File "multiprocess/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "multiprocess/pool.py", line 44, in mapstar
    return list(map(*args))
  File "pathos/helpers/mp_helper.py", line 15, in <lambda>
    func = lambda args: f(*args)
  File "my_code.py", line 939, in gradients_with_index
    return (tf.gradients(Y, variables), b_idx)
  File "tensorflow/python/ops/gradients_impl.py", line 448, in gradients
    colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients_impl.py", line 188, in _PendingCount
    between_op_list, between_ops, colocate_gradients_with_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1288, in MaybeCreateControlFlowState
    loop_state.AddWhileContext(op, between_op_list, between_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1103, in AddWhileContext
    grad_state = GradLoopState(forward_ctxt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 737, in __init__
    cnt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 2282, in AddBackPropLoopCounter
    merge_count = merge([enter_count, enter_count])[0]
  File "tensorflow/python/ops/control_flow_ops.py", line 404, in merge
    return gen_control_flow_ops._merge(inputs, name)
  File "tensorflow/python/ops/gen_control_flow_ops.py", line 150, in _merge
    result = _op_def_lib.apply_op("Merge", inputs=inputs, name=name)
  File "tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 1273, in __init__
    self._control_flow_context.AddOp(self)
  File "tensorflow/python/ops/control_flow_ops.py", line 2147, in AddOp
    self._AddOpInternal(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2177, in _AddOpInternal
    self._MaybeAddControlDependency(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2204, in _MaybeAddControlDependency
    op._add_control_input(self.GetControlPivot().op)
AttributeError: 'NoneType' object has no attribute 'op'

Here is another traceback. Note that the error is different.

Traceback (most recent call last):
  File "tensorflow/python/ops/control_flow_ops.py", line 869, in AddForwardAccumulator
    enter_acc = self.forward_context.AddValue(acc)
  File "tensorflow/python/ops/control_flow_ops.py", line 2115, in AddValue
    self._outer_context.AddInnerOp(enter.op)
  File "tensorflow/python/framework/ops.py", line 3355, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "tensorflow/python/framework/ops.py", line 3375, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError

The ProcessPool version runs into this error:

_pickle.PicklingError: Can't pickle <class 'tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper'>: it's not found as tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper

The tf.gradients() function is not thread-safe. It makes a series of complex, non-atomic modifications to the graph, and these modifications are not protected by a lock. In particular, using tf.gradients() on a graph that contains control-flow operations (such as tf.while_loop()) appears to be problematic when you run it concurrently.
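A minimal sketch (my addition, not part of the original answer) of one way to avoid the graph corruption while keeping the pool-based structure: serialize every tf.gradients() call behind a shared threading.Lock, so only one thread mutates the graph at a time. The Xs/Ys setup just mirrors the MWE above.

import threading

import numpy as np
import tensorflow as tf
from pathos.pools import ThreadPool

Xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for i in range(3)]
Ys = [Xs[0]*Xs[1]*Xs[2], Xs[0]/Xs[1]*Xs[2], Xs[0]/Xs[1]/Xs[2]]

# Illustrative workaround: a single lock serializes all graph mutation,
# so concurrent workers can no longer interleave tf.gradients() internals.
graph_lock = threading.Lock()

def compute_grad(YX):
    with graph_lock:  # only one thread builds gradient ops at a time
        return tf.gradients(YX[0], YX[1])

tp = ThreadPool(3)
res = tp.map(compute_grad, zip(Ys, Xs))
print(res)

This should sidestep the AttributeError/AssertionError above, but as the next paragraph notes, it also removes any parallel speed-up, because the gradient construction now runs one call after another.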

Note that even a thread-safe implementation of parallel calls to tf.gradients() would be unlikely to speed it up. The function performs no I/O and does not call any native methods that release Python's GIL, so execution would most likely be serialized anyway. Implementing multiprocessing-based parallelism would require additional system calls to access the shared graph (plus lock acquisition and release), so that is unlikely to be faster either.
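A sketch of the alternative this implies (my illustration, not from the answer): build all the gradient ops sequentially on the main thread, then parallelize only the Session.run() calls, which hand execution to the C++ runtime and release the GIL while they run.

import numpy as np
import tensorflow as tf
from pathos.pools import ThreadPool

Xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for i in range(3)]
Ys = [Xs[0]*Xs[1]*Xs[2], Xs[0]/Xs[1]*Xs[2], Xs[0]/Xs[1]/Xs[2]]

# Graph construction happens once, on one thread, so tf.gradients() is never
# called concurrently.
grads = [tf.gradients(Y, X) for Y, X in zip(Ys, Xs)]

sess = tf.Session()

def run_grad(g):
    # Session.run() is thread-safe and releases the GIL while the runtime
    # executes the graph, so these calls can genuinely overlap.
    return sess.run(g)

tp = ThreadPool(3)
print(tp.map(run_grad, grads))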
