Tensorflow Transformer ValueError: Dimension must be 5 but is 4

I am trying to follow this tutorial for a transformer model.
When I run this code:

for epoch in range(EPOCHS):
  start = time.time()

  train_loss.reset_states()
  train_accuracy.reset_states()

  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_batches):
    train_step(inp, tar)

I get the following error:

 ValueError: Dimension must be 5 but is 4 for '{{node transformer_1/decoder_2/decoder_layer_5/multi_head_attention_20/transpose_3}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](transformer_1/decoder_2/decoder_layer_5/multi_head_attention_20/MatMul_1, transformer_1/decoder_2/decoder_layer_5/multi_head_attention_20/transpose_3/perm)' with input shapes: [?,?,8,?,16], [4].

Here is the full stack trace:

 <ipython-input-55-a445c57427f6>:21 train_step  *
        predictions, _ = transformer(inp, tar_inp,
    <ipython-input-42-150e34827f23>:20 call  *
        dec_output, attention_weights = self.decoder(
    <ipython-input-40-6f1a58379354>:29 call  *
        x, block1, block2 = self.dec_layers[i](x, enc_output, training,
    <ipython-input-36-6dbff75f5f34>:22 call  *
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    <ipython-input-30-24e842e0e7e6>:40 call  *
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:206 wrapper  **
        return target(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py:2227 transpose_v2
        return transpose(a=a, perm=perm, name=name, conjugate=conjugate)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py:2308 transpose
        return transpose_fn(a, perm, name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_array_ops.py:11653 transpose
        "Transpose", x=x, perm=perm, name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py:601 _create_op_internal
        compute_device)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:3565 _create_op_internal
        op_def=op_def)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:2042 __init__
        control_input_ops, op_def)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:1883 _create_c_op
        raise ValueError(str(e))

I'm not sure whether this is related, but I had to change the definition of train_step_signature by adding a None: I changed tf.TensorSpec(shape=(None, None), dtype=tf.int64) to tf.TensorSpec(shape=(None, None, None), dtype=tf.int64).

Any ideas why this is happening?

I ran the notebook in Google Colab (GPU runtime, standard memory) and had no issues. It ran through all 20 epochs.

There is no issue with the notebook code itself. Restart or factory-reset the runtime, select a GPU instance, and re-run the notebook as is, without modifications, in Google Colab.

for epoch in range(EPOCHS):
  start = time.time()

  train_loss.reset_states()
  train_accuracy.reset_states()

  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_batches):
    train_step(inp, tar)

    if batch % 50 == 0:
      print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print(f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')

  print(f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

  print(f'Time taken for 1 epoch: {time.time() - start:.2f} secs\n')

Output from the final epoch:

Epoch 20 Batch 0 Loss 1.3774 Accuracy 0.7028
Epoch 20 Batch 50 Loss 1.3647 Accuracy 0.6955
Epoch 20 Batch 100 Loss 1.3832 Accuracy 0.6918
Epoch 20 Batch 150 Loss 1.3965 Accuracy 0.6898
Epoch 20 Batch 200 Loss 1.4004 Accuracy 0.6886
Epoch 20 Batch 250 Loss 1.4024 Accuracy 0.6880
Epoch 20 Batch 300 Loss 1.4042 Accuracy 0.6874
Epoch 20 Batch 350 Loss 1.4107 Accuracy 0.6863
Epoch 20 Batch 400 Loss 1.4136 Accuracy 0.6858
Epoch 20 Batch 450 Loss 1.4167 Accuracy 0.6856
Epoch 20 Batch 500 Loss 1.4220 Accuracy 0.6844
Epoch 20 Batch 550 Loss 1.4244 Accuracy 0.6839
Epoch 20 Batch 600 Loss 1.4289 Accuracy 0.6834
Epoch 20 Batch 650 Loss 1.4330 Accuracy 0.6827
Epoch 20 Batch 700 Loss 1.4365 Accuracy 0.6822
Epoch 20 Batch 750 Loss 1.4400 Accuracy 0.6817
Epoch 20 Batch 800 Loss 1.4438 Accuracy 0.6811
Saving checkpoint for epoch 20 at ./checkpoints/train/ckpt-4
Epoch 20 Loss 1.4444 Accuracy 0.6809
Time taken for 1 epoch: 59.03 secs
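
The shapes in the error message ([?,?,8,?,16] against a perm of length 4) look consistent with the signature change mentioned in the question: a rank-3 token batch would produce a rank-5 padding mask, which broadcasts the attention weights up to rank 5 before the transpose. The following is a minimal sketch of that shape arithmetic in plain Python, not TensorFlow code. It assumes the tutorial builds its padding mask as seq[:, tf.newaxis, tf.newaxis, :], it ignores the look-ahead mask, and the batch shapes (64, 40) and (64, 1, 40) are hypothetical values chosen for illustration.

```python
from itertools import zip_longest
from math import prod

NUM_HEADS, DEPTH = 8, 16      # the 8 and 16 that appear in the error message
PERM = [0, 2, 1, 3]           # perm passed to the failing tf.transpose call


def broadcast_shape(a, b):
    """Right-aligned broadcasting of two fully known shapes."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(max(x, y))
    return tuple(reversed(out))


def attention_output_shape(token_shape):
    """Shape of scaled_attention just before the transpose (padding mask only)."""
    batch = token_shape[0]
    # Assumed mask construction: two new axes inserted after the batch axis.
    mask = (batch, 1, 1) + tuple(token_shape[1:])
    # split_heads reshapes to (batch, -1, heads, depth), so every axis after
    # the batch axis is flattened into a single sequence axis.
    seq_len = prod(token_shape[1:])
    logits = (batch, NUM_HEADS, seq_len, seq_len)
    weights = broadcast_shape(logits, mask)   # logits + mask * -1e9
    return weights[:-1] + (DEPTH,)            # matmul(weights, v)


rank2 = attention_output_shape((64, 40))      # (batch, seq_len), as in the tutorial
rank3 = attention_output_shape((64, 1, 40))   # hypothetical rank-3 batch

print(rank2, len(rank2) == len(PERM))   # (64, 8, 40, 16) True
print(rank3, len(rank3) == len(PERM))   # (64, 64, 8, 40, 16) False
```

If this is what is happening, the more likely fix is to keep the original (None, None) signature and find out why the batches reaching train_step have an extra dimension, for example by printing inp.shape and tar.shape for one batch of train_batches.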
