
Is tf.gradients thread-safe?

I have multiple calls to tf.gradients, each of which takes some time, so I would like to make the tf.gradients calls concurrently. However, when I try to do so on my graph, I get one of several errors. I suspect tf.gradients is not thread-safe, but I have been unable to reproduce the errors with an MWE. I tried both pathos.pools.ThreadPool and pathos.pools.ProcessPool, in my MWE and in my real code; only my real code fails. Here is the MWE I tried:

from pathos.pools import ThreadPool, ProcessPool
import tensorflow as tf
import numpy as np

# Three random 10x10 constant matrices, and three outputs that depend on them.
Xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for i in range(3)]
Ys = [Xs[0]*Xs[1]*Xs[2], Xs[0]/Xs[1]*Xs[2], Xs[0]/Xs[1]/Xs[2]]

def compute_grad(YX):
    # YX is a (Y, X) pair; build the gradient ops of Y with respect to X.
    return tf.gradients(YX[0], YX[1])

# Build the gradient ops for the three pairs concurrently in three threads.
tp = ThreadPool(3)
res = tp.map(compute_grad, zip(Ys, Xs))
print(res)

Here is part of the traceback I get when I try my real code. This is the ThreadPool version:

File "pathos/threading.py", line 134, in map
    return _pool.map(star(f), zip(*args)) # chunksize
  File "multiprocess/pool.py", line 260, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "multiprocess/pool.py", line 608, in get
    raise self._value
  File "multiprocess/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "multiprocess/pool.py", line 44, in mapstar
    return list(map(*args))
  File "pathos/helpers/mp_helper.py", line 15, in <lambda>
    func = lambda args: f(*args)
  File "my_code.py", line 939, in gradients_with_index
    return (tf.gradients(Y, variables), b_idx)
  File "tensorflow/python/ops/gradients_impl.py", line 448, in gradients
    colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients_impl.py", line 188, in _PendingCount
    between_op_list, between_ops, colocate_gradients_with_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1288, in MaybeCreateControlFlowState
    loop_state.AddWhileContext(op, between_op_list, between_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1103, in AddWhileContext
    grad_state = GradLoopState(forward_ctxt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 737, in __init__
    cnt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 2282, in AddBackPropLoopCounter
    merge_count = merge([enter_count, enter_count])[0]
  File "tensorflow/python/ops/control_flow_ops.py", line 404, in merge
    return gen_control_flow_ops._merge(inputs, name)
  File "tensorflow/python/ops/gen_control_flow_ops.py", line 150, in _merge
    result = _op_def_lib.apply_op("Merge", inputs=inputs, name=name)
  File "tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 1273, in __init__
    self._control_flow_context.AddOp(self)
  File "tensorflow/python/ops/control_flow_ops.py", line 2147, in AddOp
    self._AddOpInternal(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2177, in _AddOpInternal
    self._MaybeAddControlDependency(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2204, in _MaybeAddControlDependency
    op._add_control_input(self.GetControlPivot().op)
AttributeError: 'NoneType' object has no attribute 'op'

Here is another traceback. Note that the error is different:

Traceback (most recent call last):
  File "tensorflow/python/ops/control_flow_ops.py", line 869, in AddForwardAccumulator
    enter_acc = self.forward_context.AddValue(acc)
  File "tensorflow/python/ops/control_flow_ops.py", line 2115, in AddValue
    self._outer_context.AddInnerOp(enter.op)
  File "tensorflow/python/framework/ops.py", line 3355, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "tensorflow/python/framework/ops.py", line 3375, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError

The ProcessPool version hits this error:

_pickle.PicklingError: Can't pickle <class 'tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper'>: it's not found as tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper
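As an aside, the ProcessPool failure is a separate, deterministic limitation of pickle rather than a thread-safety issue: a class defined inside a function body (as TF's TFShouldUseWarningWrapper is) has `<locals>` in its qualified name, so pickle cannot look it up at module level when serializing it for another process. A minimal stdlib reproduction, where `LocalWrapper` is a hypothetical stand-in for TF's wrapper class:

```python
import pickle

def make_wrapper():
    # A class defined inside a function has '<locals>' in its qualified
    # name, so pickle cannot locate it by module-level lookup.
    class LocalWrapper:
        pass
    return LocalWrapper()

try:
    pickle.dumps(make_wrapper())
    print("picklable")
except Exception:
    # Depending on the Python version this surfaces as PicklingError
    # or AttributeError, but it always fails.
    print("unpicklable")
```

This is why process-based pools fail before any gradient computation even starts: the arguments cannot be shipped to the worker processes.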

The tf.gradients() function is not thread-safe. It makes a complicated and non-atomic series of modifications to the graph, and these modifications are not protected by a lock. In particular, using tf.gradients() on a graph that contains control-flow operations (such as tf.while_loop()) appears to be problematic when you run it concurrently.

Note that even a thread-safe implementation of parallel tf.gradients() calls would be unlikely to speed it up. The function performs no I/O and calls no native methods that release Python's GIL, so execution would most likely be serialized anyway. Implementing multiprocessing-based parallelism would require additional system calls to access the shared graph (and to acquire/release locks), so that is unlikely to be faster either.
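If the surrounding code must stay threaded for other reasons, correctness can still be recovered by serializing the graph-mutating call behind a single lock, accepting (per the above) that it will be no faster than a plain loop. A minimal stdlib sketch of the pattern, where `toy_grad` is a hypothetical stand-in for `tf.gradients`:

```python
import threading

# tf.gradients mutates the shared graph, so concurrent calls race with each
# other. Holding one lock around the graph-mutating call restores
# correctness; it does not add speed.
graph_lock = threading.Lock()
results = {}

def toy_grad(y, x):
    # Hypothetical stand-in for tf.gradients(y, x); real code would pass
    # the actual tensors and call tf.gradients here.
    return y * x

def compute_grad_serialized(idx, y, x):
    # Hold the lock only around the graph-mutating call; any per-item
    # pre-/post-processing can run outside it.
    with graph_lock:
        results[idx] = toy_grad(y, x)

threads = [threading.Thread(target=compute_grad_serialized,
                            args=(i, i + 1.0, 2.0))
           for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)
```

The simplest fix of all is to drop the pool and call tf.gradients sequentially, which is what the lock effectively reduces this to.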
