How can I solve the “Could not find valid device” error with CTC loss in TensorFlow 2?

I want to calculate the CTC loss for an OCR problem, but whenever I run the code I get:

NotFoundError: Could not find valid device for node. Node:{{node OneHot}} All kernels registered for op OneHot:

I am using TensorFlow 2 on Google Colab.

Below is the relevant part of the code:

import tensorflow as tf

def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    # shape of predictions (batch_size, timeframes, dictionary_size) --> (64, 20, 30)
    # shape of labels (batch_size, max_label_seq_length) --> (64, 20)
    # label_length is a scalar, here 20; the same for logit_length
    # tf.nn.ctc_loss expects 1-D length tensors of shape (batch_size,)
    label_length_tensor = tf.constant(label_length, shape=(labels.shape[0],))
    logit_length_tensor = tf.constant(logit_length, shape=(labels.shape[0],))
    # time-major logits: (timeframes, batch_size, dictionary_size)
    logits = tf.transpose(predictions, (1, 0, 2))
    # keyword arguments: the positional order is label_length, then logit_length
    loss = tf.nn.ctc_loss(labels, logits,
                          label_length=label_length_tensor,
                          logit_length=logit_length_tensor)
    return loss

I expect to get the calculated loss, but I get the following exception instead:

NotFoundError                             Traceback (most recent call last)

<ipython-input-236-6ddd6739d50d> in <module>()
     14         #print(batch)
     15         #print(target)
---> 16         batch_loss, t_loss, oc_loss, accuracy = train_step(img_tensor, target)
     17         total_loss += t_loss
     18         one_char_loss += oc_loss

12 frames

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

NotFoundError: Could not find valid device for node.
Node:{{node OneHot}}
All kernels registered for op OneHot :
  device='XLA_CPU'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_CPU_JIT'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_GPU_JIT'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='GPU'; TI in [DT_INT64]; T in [DT_INT64]
  device='GPU'; TI in [DT_INT32]; T in [DT_INT64]
  device='GPU'; TI in [DT_UINT8]; T in [DT_INT64]
  device='GPU'; TI in [DT_INT64]; T in [DT_INT32]
  device='GPU'; TI in [DT_INT32]; T in [DT_INT32]
  device='GPU'; TI in [DT_UINT8]; T in [DT_INT32]
  device='GPU'; TI in [DT_INT64]; T in [DT_BOOL]
  device='GPU'; TI in [DT_INT32]; T in [DT_BOOL]
  device='GPU'; TI in [DT_UINT8]; T in [DT_BOOL]
  device='GPU'; TI in [DT_INT64]; T in [DT_DOUBLE]
  device='GPU'; TI in [DT_INT32]; T in [DT_DOUBLE]
  device='GPU'; TI in [DT_UINT8]; T in [DT_DOUBLE]
  device='GPU'; TI in [DT_INT64]; T in [DT_FLOAT]
  device='GPU'; TI in [DT_INT32]; T in [DT_FLOAT]
  device='GPU'; TI in [DT_UINT8]; T in [DT_FLOAT]
  device='GPU'; TI in [DT_INT64]; T in [DT_HALF]
  device='GPU'; TI in [DT_INT32]; T in [DT_HALF]
  device='GPU'; TI in [DT_UINT8]; T in [DT_HALF]

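Change the dtype of the labels to tf.int64 or tf.int32, for example labels = tf.cast(labels, tf.int32). tf.nn.ctc_loss one-hot encodes the dense labels internally (that is the OneHot node in the error), and every kernel listed for OneHot accepts only uint8, int32 or int64 indices (TI). With labels of any other dtype (float, for example), no device has a matching kernel.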
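A minimal sketch of the fix, assuming the labels currently arrive as a float tensor (their dtype is not shown in the question) and keeping the question's scalar label_length and logit_length. It casts the labels to tf.int32, builds 1-D length tensors of shape (batch_size,), and passes everything to tf.nn.ctc_loss by keyword. The dummy batch in the __main__ part (4 sequences, 20 time steps, 30 classes, labels of length 5) is hypothetical and only there to exercise the function.

```python
import tensorflow as tf


def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    """Per-example CTC loss for dense (padded) labels.

    predictions: float logits, shape (batch_size, timeframes, dictionary_size)
    labels:      class indices, shape (batch_size, max_label_seq_length);
                 assumed here to possibly arrive as float, so they are cast
    label_length, logit_length: Python ints shared by every example
    """
    # The OneHot op from the error only accepts uint8/int32/int64 indices.
    labels = tf.cast(labels, tf.int32)
    batch_size = tf.shape(labels)[0]
    # tf.nn.ctc_loss expects 1-D length tensors of shape (batch_size,).
    label_length_tensor = tf.fill([batch_size], label_length)
    logit_length_tensor = tf.fill([batch_size], logit_length)
    return tf.nn.ctc_loss(
        labels=labels,
        logits=predictions,
        label_length=label_length_tensor,
        logit_length=logit_length_tensor,
        logits_time_major=False,  # predictions are batch-major, no transpose needed
    )


if __name__ == "__main__":
    # Hypothetical dummy batch. Label values start at 1 because index 0 is
    # the default blank for dense labels. Stored as float32 to mimic the
    # failing case.
    predictions = tf.random.normal((4, 20, 30))
    labels = tf.cast(tf.random.uniform((4, 5), 1, 30, dtype=tf.int32), tf.float32)
    loss = calculate_ctc_loss(predictions, labels, label_length=5, logit_length=20)
    print(loss.shape)  # (4,)
```

Passing logits_time_major=False replaces the manual tf.transpose(predictions, (1, 0, 2)); both forms compute the same loss. Using tf.shape(labels)[0] instead of labels.shape[0] keeps the function working inside a tf.function where the batch dimension may be unknown.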
