
Vectorized beam search decoder is not faster on GPU - Tensorflow 2

I'm trying to run an RNN beam search on a tf.keras.Model in a vectorized way so that it runs completely on the GPU. However, despite wrapping everything in tf.function and vectorizing it as far as I can, it runs at exactly the same speed with or without a GPU. Attached is a minimal example with a fake model. In reality, for n=32, k=32, steps=128, which is what I want to work with, decoding takes 20 s (per n=32 samples), both on CPU and on GPU!

I must be missing something. When I train the model, a training iteration (128 steps) with batch size 512 takes 100 ms on the GPU, and a training iteration with batch size 32 takes 1 s on the CPU. The GPU isn't saturated at batch size 512. I get that I have overhead from running the steps individually and from a blocking operation per step, but in terms of computation my overhead is negligible compared to the rest of the model.

I also get that using a tf.keras.Model in this way is probably not ideal, but is there another way to wire output tensors via a function back to the input tensors, and in particular to rewire the states?

Full working example: https://gist.github.com/meowcat/e3eaa4b8543a7c8444f4a74a9074b9ae

import tensorflow as tf
tm = tf.math  # alias used below; model_decode and embed_y_to_x come from the full example

@tf.function
def decode_beam(states_init, scores_init, y_init, steps, k, n):
    states = states_init
    scores = scores_init
    xstep = embed_y_to_x(y_init)

    # Keep the results in TensorArrays
    y_chain = tf.TensorArray(dtype="int32", size=steps)
    sequences_chain = tf.TensorArray(dtype="int32", size=steps)
    scores_chain = tf.TensorArray(dtype="float32", size=steps)


    for i in range(steps):
        # model_decode is the trained model with 3.5 million trainable params.
        # Run a single step of the RNN model.
        y, states = model_decode([xstep, states])
        # Add the scores of the current step to the previous scores
        # (I left out the sequence end killer for this demo)
        scores_y = tf.expand_dims(tf.reshape(scores, y.shape[:-1]), 2) + tm.log(y)
        # Reshape into (n,k,tokens) and find the best k sequences to continue for each of n candidates
        scores_y = tf.reshape(scores_y, [n, -1])
        top_k = tm.top_k(scores_y, k, sorted=False)
        # Transform the indices. I was using tf.unravel_index but
        # `tf.debugging.set_log_device_placement(True)` indicated that this would be placed on the CPU
        # thus I rewrote it
        top_k_index = tf.reshape(
                top_k[1] + tf.reshape(tf.range(n), (-1, 1)) * scores_y.shape[1], [-1])
        ysequence = top_k_index // y.shape[2]
        ymax = top_k_index % y.shape[2]
        # this gives us two (n*k,) tensors with parent sequence (ysequence) 
        # and chosen character (ymax) per sequence.
        # For continuation, pick the states, and "return" the scores
        states = tf.gather(states, ysequence)
        scores = tf.reshape(top_k[0], [-1])
        # Write the results into the TensorArrays,
        # and embed for the next step
        xstep = embed_y_to_x(ymax)
        y_chain = y_chain.write(i, ymax)
        sequences_chain = sequences_chain.write(i, ysequence)
        scores_chain = scores_chain.write(i, scores)
    # Done: Stack up the results and return them
    sequences_final = sequences_chain.stack()
    y_final = y_chain.stack()
    scores_final = scores_chain.stack()

    return sequences_final, y_final, scores_final
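
For context, this is roughly how I drive the function; the stand-in model, sizes and initial values below are placeholders for illustration only (the real fake model is in the gist):

import tensorflow as tf
tm = tf.math

n, k, steps = 32, 32, 8          # small `steps` just for a quick check
state_dim, n_tokens = 256, 64    # assumed sizes, for illustration only

# Stand-ins with the right shapes so decode_beam can be exercised.
def embed_y_to_x(y):
    return tf.one_hot(y, n_tokens)

def model_decode(inputs):
    xstep, states = inputs
    y = tf.nn.softmax(tf.random.normal([n * k, 1, n_tokens]))
    return y, states

states_init = tf.zeros([n * k, state_dim])      # one state per (sample, beam)
scores_init = tf.zeros([n * k])                 # log-prob of the empty prefix
y_init = tf.zeros([n * k], dtype=tf.int32)      # start token for every beam

sequences, y, scores = decode_beam(states_init, scores_init, y_init, steps, k, n)
print(sequences.shape, y.shape, scores.shape)   # (steps, n*k) each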

There was a lot going on here. I will comment on it, since it might help others resolve TensorFlow performance issues.

Profiling

  • The GPU profiler library (CUPTI) was not loading correctly on the cluster, which stopped me from doing any useful profiling on the GPU. Once that was fixed, I got useful GPU profiles.

Note this very useful answer (the only one on the web) that shows how to profile arbitrary TensorFlow 2 code, rather than Keras training:

https://stackoverflow.com/a/56698035/1259675

logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)

# run any @tf.function decorated functions here

sequences, y, scores = decode_beam_steps(
    y_init, states_init, scores_init, 
    steps = steps, k = k, n = n, pad_mask = pad_mask)  

with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
tf.summary.trace_off()
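
The exported trace can then be opened in TensorBoard (e.g. tensorboard --logdir log) under the Profile tab.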

Note that an old Chromium version was needed to view the profiling results, since at the time of writing (2020-04-17) this failed in current Chrome/Chromium.

Small optimizations

  • The graph was made a bit lighter, but not significantly faster, by setting unroll=True in the LSTM layers used by the model (not shown here): since only one step is run at a time, the symbolic loop only adds clutter. Unrolling does, however, significantly cut the time for the first call of the function above, when AutoGraph builds the graph. Note that this time is enormous (see below).

    unroll=False (the default) builds the graph in 300 seconds; unroll=True builds it in 100 seconds. Note that run-time performance itself stays the same (15-20 s/iteration for n=32, k=32).

implementation=1 made it slightly slower, so I stayed with the default of implementation=2 (a sketch of both settings follows below).
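
For illustration, this is roughly what those two settings look like on a single-step decoder LSTM; the layer sizes and names here are assumptions, not the actual model:

import tensorflow as tf

# Hypothetical single-step decoder LSTM; sizes are placeholders.
state_dim, embed_dim = 256, 64

decoder_lstm = tf.keras.layers.LSTM(
    state_dim,
    return_state=True,    # states are fed back in for the next decoding step
    unroll=True,          # sequence length is 1, so unrolling removes the symbolic loop
    implementation=2,     # the default; implementation=1 was slightly slower here
)

# One time step per call; the beam search loop re-feeds the states explicitly.
xstep = tf.keras.Input((1, embed_dim))
h_in = tf.keras.Input((state_dim,))
c_in = tf.keras.Input((state_dim,))
y, h_out, c_out = decoder_lstm(xstep, initial_state=[h_in, c_in])
model_step = tf.keras.Model([xstep, h_in, c_in], [y, h_out, c_out])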

Using tf.while_loop instead of relying on AutoGraph

  • The for i in range(steps) loop. I had this both in the inlined version shown above and in a modularized one:
    for i in range(steps):
        ystep, states = model_decode([xstep, states])
        ymax, ysequence, states, scores = model_beam_step(
            ystep, states, scores, k, n, pad_mask)
        xstep = model_rtox(ymax)
        y_chain = y_chain.write(i, ymax)
        sequences_chain = sequences_chain.write(i, ysequence)
        scores_chain = scores_chain.write(i, scores)

where model_beam_step does all the beam search math. Unsurprisingly, both performed exactly equally badly, and in particular, both took ~100/300 seconds on the first run while AutoGraph traced the graph. Further, tracing the graph with the profiler gives a huge 30-50 MB file that won't easily load in TensorBoard and more or less crashes it. The profile had dozens of parallel GPU streams with a single operation each.

Substituting this with a tf.while_loop slashed the setup time to zero (back_prop=False makes only a very small difference), and produces a nice 500 KB graph that can easily be inspected in TensorBoard and profiled in a useful way, with 4 GPU streams (a sketch of the loop-variable initialization follows after the snippet).


    beam_steps_cond = lambda i, y_, seq_, sc_, xstep, states, scores: i < steps
    def decode_beam_steps_body(i, y_, seq_, sc_, xstep, states, scores):
        y, states = model_decode([xstep, states])
        ymax, ysequence, states, scores = model_beam_step(
                y, states, scores, k, n, pad_mask)
        xstep = model_rtox(ymax)
        y_ = y_.write(i, ymax)
        seq_ = seq_.write(i, ysequence)
        sc_ = sc_.write(i, scores)
        i = i + 1
        return i, y_, seq_, sc_, xstep, states, scores
    
    _, y_chain, sequences_chain, scores_chain, _, _, _ = \
        tf.while_loop(
            cond = beam_steps_cond,
            body = decode_beam_steps_body,
            loop_vars = [i, y_chain, sequences_chain, scores_chain,
                         xstep, states, scores],
            back_prop = False
            )
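
For completeness, the loop variables above need to be set up before entering the tf.while_loop; roughly like this, inside the same function (names as in the snippets above, with model_rtox embedding the start tokens like embed_y_to_x did in the question's version):

    # Inside the same @tf.function, before entering the tf.while_loop:
    i = tf.constant(0)
    y_chain = tf.TensorArray(dtype="int32", size=steps)
    sequences_chain = tf.TensorArray(dtype="int32", size=steps)
    scores_chain = tf.TensorArray(dtype="float32", size=steps)
    xstep = model_rtox(y_init)   # embed the start tokens for the first step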

Finally, the real problem

Being able to look at the profile in a meaningful way finally showed me that the real issue was an output postprocessing function that runs on the CPU. I didn't suspect it because it had been running fast earlier, but I had overlooked that a beam search modification I made leads to far more than k sequences per candidate, which slows the postprocessing down massively. It was therefore eating every benefit I gained from making the decoding step efficient on the GPU. Without this postprocessing, the GPU runs at >2 iterations/sec. Refactoring the postprocessing (which is extremely fast when done right) into TensorFlow resolved the issue; a sketch of the kind of refactoring I mean follows below.
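
The actual postprocessing isn't shown here, but to illustrate the kind of refactoring I mean: if the postprocessing were the usual backtracking step that reconstructs token sequences from the (sequences_final, y_final) chains returned by the decoder, it could stay entirely inside TensorFlow. The function below is a sketch under that assumption, not my exact code:

import tensorflow as tf

@tf.function
def backtrack_beams(sequences_final, y_final):
    """Reconstruct full token sequences without leaving TensorFlow.

    sequences_final, y_final: (steps, n*k) int32 tensors as returned by
    decode_beam above: parent-beam index and chosen token per step and beam.
    Returns an (n*k, steps) int32 tensor of token ids.
    """
    steps = tf.shape(y_final)[0]
    n_beams = tf.shape(y_final)[1]
    tokens = tf.TensorArray(dtype=tf.int32, size=steps)
    ptr = tf.range(n_beams)                       # start from every final beam
    for i in tf.range(steps - 1, -1, -1):         # walk the chains backwards
        tokens = tokens.write(i, tf.gather(y_final[i], ptr))
        ptr = tf.gather(sequences_final[i], ptr)  # hop to the parent beam
    return tf.transpose(tokens.stack())

Since everything here is a TensorFlow op inside a tf.function, the per-step gathers run on the GPU instead of pulling intermediate results back into Python/NumPy after every decode.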
