简体   繁体   中英

How to test distributed layers on Tensorflow?

I am trying to test a layer that I will add later in a distributed model however I want to be sure that it works before.

This is the layer in question:

class BNShuffler(tf.Module):
    def __init__(
        self,
        global_batch_size: int=64
    ):
        super(BNShuffler, self).__init__()
        self.global_batch_size = global_batch_size
        self.idx = tf.Variable(tf.range(global_batch_size), trainable=False)

    def __call__(self, x, shuffle=True):
        batch_size = tf.shape(x)[0]
        replica_context = tf.distribute.get_replica_context()
        if replica_context is not None:
            replica_id   = replica_context.replica_id_in_sync_group
            num_replicas = replica_context.num_replicas_in_sync
            x_target = _cross_replica_concat(x, replica_id, replica_context, num_replicas)
        else:
            x_target = x
            num_replicas = 1
            replica_id = 0

        if shuffle:
            self.idx.assign(tf.random.shuffle(self.idx))
            x_shuffled = tf.gather(x_target, self.idx)
            return x_shuffled[replica_id * batch_size: (replica_id + 1) * batch_size]
        
        else:
            unshuffled_idx = tf.math.invert_permutation(self.idx)
            x_unshuffled   = tf.gather(x_target, unshuffled_idx)
            self.idx.assign(tf.range(self.global_batch_size))
            return x_unshuffled[replica_id * batch_size: (replica_id + 1) * batch_size]


def _cross_replica_concat(x, replica_id, replica_context, num_replicas):
    x_shape = tf.shape(x)
    result_tensor = tf.scatter_nd(
        indices=[[replica_id]],
        updates=[x],
        shape=tf.concat([[num_replicas], x_shape], axis=0),
    )
    result_tensor = replica_context.all_reduce(
        tf.distribute.ReduceOp.SUM, result_tensor
    )

    return tf.reshape(result_tensor, x_shape)

The goal of this layer is to shuffle data across all gpu when shuffle=True and put them back when shuffle=False so it will be applied twice.

In order to test it I tried to generate a simple distributed dataset and apply my shuffler based onthis tutorial but it throws me an error.

Code:

global_batch_size = 6
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
train_dataset = tf.data.Dataset.range(12).batch(global_batch_size, drop_remainder=False)
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
with strategy.scope():
    shuffler=BNShuffler(global_batch_size)

def train_step(x):
    print(x)
    x = tf.reshape(x, [-1, 1, 1])
    x = shuffler(x, True)
    print(x)
    x = shuffler(x, False)
    print(x)
    return strategy.reduce(tf.distribute.ReduceOp.SUM, x,
                         axis=None)


for epoch in range(1):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for x in dist_dataset:
        strategy.run(train_step, args=(x,))

Output:

tf.Tensor([0 1 2], shape=(3,), dtype=int64)
tf.Tensor([3 4 5], shape=(3,), dtype=int64)
INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Error reported to Coordinator: 2 root error(s) found.
  (0) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
     [[NcclAllReduce_1/_6]]
  (1) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__all_reduce_1284]

Function call stack:
_all_reduce -> _all_reduce
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
    yield
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_run.py", line 228, in _call_for_each_replica
    **merge_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py", line 572, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py", line 3080, in batch_all_reduce
    options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py", line 2374, in batch_reduce_to
    return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_strategy.py", line 697, in _batch_reduce_to
    options=self._communication_options.merge(options))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py", line 426, in batch_reduce
    options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py", line 819, in batch_reduce_implementation
    [v[0] for v in value_destination_pairs])
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py", line 831, in _batch_all_reduce
    dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py", line 860, in _do_batch_all_reduce
    device_grad_packs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_utils.py", line 45, in aggregate_gradients_using_nccl
    agg_grads = nccl_ops.all_sum(single_grads)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nccl_ops.py", line 47, in all_sum
    return _apply_all_reduce('sum', tensors)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nccl_ops.py", line 234, in _apply_all_reduce
    return def_function.function(_all_reduce)()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 895, in _call
    filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
     [[NcclAllReduce_1/_6]]
  (1) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__all_reduce_1284]

Function call stack:
_all_reduce -> _all_reduce

---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
<ipython-input-22-0888591f3d27> in <module>
     15     num_batches = 0
     16     for x in dist_dataset:
---> 17         strategy.run(train_step, args=(x,))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py in run(***failed resolving arguments***)
   1257       fn = autograph.tf_convert(
   1258           fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
-> 1259       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
   1260 
   1261   def reduce(self, reduce_op, value, axis):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py in call_for_each_replica(self, fn, args, kwargs)
   2728       kwargs = {}
   2729     with self._container_strategy().scope():
-> 2730       return self._call_for_each_replica(fn, args, kwargs)
   2731 
   2732   def _call_for_each_replica(self, fn, args, kwargs):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_strategy.py in _call_for_each_replica(self, fn, args, kwargs)
    627   def _call_for_each_replica(self, fn, args, kwargs):
    628     return mirrored_run.call_for_each_replica(
--> 629         self._container_strategy(), fn, args, kwargs)
    630 
    631   def _configure(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_run.py in call_for_each_replica(strategy, fn, args, kwargs)
     91     fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
     92 
---> 93   return _call_for_each_replica(strategy, fn, args, kwargs)
     94 
     95 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_run.py in _call_for_each_replica(distribution, fn, args, kwargs)
    232     for t in threads:
    233       t.should_run.set()
--> 234     coord.join(threads)
    235 
    236   return distribute_utils.regroup(tuple(t.main_result for t in threads))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/coordinator.py in join(self, threads, stop_grace_period_secs, ignore_live_threads)
    387       self._registered_threads = set()
    388       if self._exc_info_to_raise:
--> 389         six.reraise(*self._exc_info_to_raise)
    390       elif stragglers:
    391         if ignore_live_threads:

/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
    701             if value.__traceback__ is not tb:
    702                 raise value.with_traceback(tb)
--> 703             raise value
    704         finally:
    705             value = None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/coordinator.py in stop_on_exception(self)
    295     """
    296     try:
--> 297       yield
    298     except:  # pylint: disable=bare-except
    299       self.request_stop(ex=sys.exc_info())

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_run.py in _call_for_each_replica(distribution, fn, args, kwargs)
    226               variable_scope.variable_scope(mtt_captured_var_scope):
    227             merge_result = threads[0].merge_fn(distribution, *merge_args,
--> 228                                                **merge_kwargs)
    229           for r, t in enumerate(threads):
    230             t.merge_result = distribute_utils.select_replica(r, merge_result)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    570   def wrapper(*args, **kwargs):
    571     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
--> 572       return func(*args, **kwargs)
    573 
    574   if inspect.isfunction(func) or inspect.ismethod(func):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py in batch_all_reduce(strategy, *value_flat)
   3078       return strategy.extended.batch_reduce_to(
   3079           reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
-> 3080           options)
   3081 
   3082     if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py in batch_reduce_to(self, reduce_op, value_destination_pairs, options)
   2372     if isinstance(reduce_op, six.string_types):
   2373       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
-> 2374     return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
   2375 
   2376   def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_strategy.py in _batch_reduce_to(self, reduce_op, value_destination_pairs, options)
    695         reduce_op,
    696         value_destination_pairs,
--> 697         options=self._communication_options.merge(options))
    698 
    699   def _update(self, var, fn, args, kwargs, group):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py in batch_reduce(self, reduce_op, value_destination_pairs, options)
    424       options = collective_util.Options()
    425     return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
--> 426                                             options)
    427 
    428   def broadcast(self, tensor, destinations):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py in batch_reduce_implementation(self, reduce_op, value_destination_pairs, options)
    817     if _all_devices_match(value_destination_pairs):
    818       return self._batch_all_reduce(reduce_op,
--> 819                                     [v[0] for v in value_destination_pairs])
    820     else:
    821       return [

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py in _batch_all_reduce(self, reduce_op, per_replica_values)
    829         cross_device_utils.split_by_sparsity(per_replica_values))
    830     if dense_values:
--> 831       dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    832     else:
    833       dense_results = []

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_ops.py in _do_batch_all_reduce(self, reduce_op, dense_values)
    858       # TODO(yuefengz): merge this into the all-reduce library.
    859       reduced = cross_device_utils.aggregate_gradients_using_nccl(
--> 860           device_grad_packs)
    861     else:
    862       # TODO(yuefengz): check that gpu ids in `destinations` are in ascending

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/cross_device_utils.py in aggregate_gradients_using_nccl(replica_grads)
     43   for single_g_and_v in zip(*replica_grads):
     44     single_grads = [g for g, _ in single_g_and_v]
---> 45     agg_grads = nccl_ops.all_sum(single_grads)
     46     agg_all_g_and_v.append(
     47         [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nccl_ops.py in all_sum(tensors)
     45     the same device as `tensors[i]`.
     46   """
---> 47   return _apply_all_reduce('sum', tensors)
     48 
     49 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nccl_ops.py in _apply_all_reduce(reduction, tensors)
    232     # Nccl ops will block unless they are executed concurrently such as in a
    233     # graph or a defun.
--> 234     return def_function.function(_all_reduce)()
    235   else:
    236     return _all_reduce()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    893       # If we did not create any variables the trace we have is good enough.
    894       return self._concrete_stateful_fn._call_flat(
--> 895           filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
    896 
    897     def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(self._inference_function.call(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InternalError: 2 root error(s) found.
  (0) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
     [[NcclAllReduce_1/_6]]
  (1) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__all_reduce_1284]

Function call stack:
_all_reduce -> _all_reduce

How should I use the strategy to make my test ?

The major reason why you got the error messages may be because tf.distribute.get_replica_context().all_reduce() does not always work in eager mode. It will work properly in graph mode.(See example codes below)

There are also some other potential problems in your codes.

  1. pass aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA to tf.Variable to make sure it is synchronized across replicas.
  2. strategy.reduce() shouldn't be called inside train_step

Example codes:

tf.random.set_seed(88883)
strategy = tf.distribute.MirroredStrategy()
print(f'using distribution strategy\nnumber of gpus:{strategy.num_replicas_in_sync}')

global_batch_size = 6
with strategy.scope():
    idx=tf.Variable(tf.range(global_batch_size),trainable=False,aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
train_dataset = tf.data.Dataset.range(12).batch(global_batch_size, drop_remainder=True)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
train_dataset = train_dataset.with_options(options)
ds = strategy.experimental_distribute_dataset(train_dataset)

def shuffler(x,idx,global_batch_size,shuffle=False):
    ctx = tf.distribute.get_replica_context()
    num_replicas = ctx.num_replicas_in_sync
    replica_id = ctx.replica_id_in_sync_group

    x_shape = global_batch_size//num_replicas
    result_tensor = tf.scatter_nd([[replica_id]],[x],[num_replicas, x_shape])
    result_tensor = tf.reshape(result_tensor, [global_batch_size])

    all_x = ctx.all_reduce(
        tf.distribute.ReduceOp.SUM, result_tensor
    )

    if shuffle:
        idx.assign(tf.random.shuffle(idx))
        x_shuffled = tf.gather(all_x, idx)
        return x_shuffled[replica_id * x_shape: (replica_id + 1) * x_shape]

    else:
        unshuffled_idx = tf.math.invert_permutation(idx)
        x_unshuffled   = tf.gather(all_x, unshuffled_idx)
        idx.assign(tf.range(global_batch_size))
        return x_unshuffled[replica_id * x_shape: (replica_id + 1) * x_shape]

def train_step(x):
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    tf.print('before shuffle',replica_id, x, output_stream=sys.stdout)
    x = shuffler(x,idx,global_batch_size,shuffle = True)
    tf.print('shuffle = True',replica_id, x, output_stream=sys.stdout)
    x = shuffler(x,idx,global_batch_size,shuffle = False)
    tf.print('shuffle = False',replica_id, x, output_stream=sys.stdout)

#add @tf.function to run in graph mode
@tf.function
def distributed_train_step(x):
  strategy.run(train_step, args=(x,))

list(map(distributed_train_step,ds))

Expected outputs:

using distribution strategy
number of gpus:3
before shuffle 0 [0 1]
before shuffle 1 [2 3]
before shuffle 2 [4 5]
shuffle = True 0 [4 1]
shuffle = True 1 [0 2]
shuffle = True 2 [3 5]
shuffle = False 0 [0 1]
shuffle = False 1 [2 3]
shuffle = False 2 [4 5]
before shuffle 0 [6 7]
before shuffle 1 [8 9]
before shuffle 2 [10 11]
shuffle = True 0 [9 8]
shuffle = True 1 [7 11]
shuffle = True 2 [10 6]
shuffle = False 0 [6 7]
shuffle = False 1 [8 9]
shuffle = False 2 [10 11]

Finally, please also note that tf.distribute.get_replica_context().all_gather() is made for exactly what you want to do instead of all_reduce() although that all_reduce() can do the same thing.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM