
How can I solve the “Could not find valid device” issue with CTC loss in TensorFlow 2?

I want to compute the CTC loss for an OCR problem, but whenever I run the code it fails with:

NotFoundError: Could not find valid device for node. Node: {{node OneHot}} All kernels registered for op OneHot:

I am using TensorFlow 2 on Google Colab.

Here is the key part of the code:

def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    # shape of predictions (batch_size, timeframes, dictionary_size) --> (64, 20, 30)
    # shape of labels (batch_size, max_label_seq_length) --> (64, 20)
    label_length_tensor = tf.constant(label_length, shape=(labels.shape[0], 1)) 
    logit_length_tensor = tf.constant(logit_length, shape=(labels.shape[0], 1))
    # label_length is a scalar, here 20. the same for logit_length
    logits = tf.transpose(predictions, (1, 0, 2))
    loss = tf.nn.ctc_loss(labels, logits, logit_length_tensor, label_length_tensor)
    return loss

I expect to get the computed loss, but I get the following exception:

NotFoundError                             Traceback (most recent call last)

<ipython-input-236-6ddd6739d50d> in <module>()
     14         #print(batch)
     15         #print(target)
---> 16         batch_loss, t_loss, oc_loss, accuracy = train_step(img_tensor, target)
     17         total_loss += t_loss
     18         one_char_loss += oc_loss

12 frames

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

NotFoundError: Could not find valid device for node.
Node:{{node OneHot}}
All kernels registered for op OneHot :
  device='XLA_CPU'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_CPU_JIT'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='XLA_GPU_JIT'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
  device='GPU'; TI in [DT_INT64]; T in [DT_INT64]
  device='GPU'; TI in [DT_INT32]; T in [DT_INT64]
  device='GPU'; TI in [DT_UINT8]; T in [DT_INT64]
  device='GPU'; TI in [DT_INT64]; T in [DT_INT32]
  device='GPU'; TI in [DT_INT32]; T in [DT_INT32]
  device='GPU'; TI in [DT_UINT8]; T in [DT_INT32]
  device='GPU'; TI in [DT_INT64]; T in [DT_BOOL]
  device='GPU'; TI in [DT_INT32]; T in [DT_BOOL]
  device='GPU'; TI in [DT_UINT8]; T in [DT_BOOL]
  device='GPU'; TI in [DT_INT64]; T in [DT_DOUBLE]
  device='GPU'; TI in [DT_INT32]; T in [DT_DOUBLE]
  device='GPU'; TI in [DT_UINT8]; T in [DT_DOUBLE]
  device='GPU'; TI in [DT_INT64]; T in [DT_FLOAT]
  device='GPU'; TI in [DT_INT32]; T in [DT_FLOAT]
  device='GPU'; TI in [DT_UINT8]; T in [DT_FLOAT]
  device='GPU'; TI in [DT_INT64]; T in [DT_HALF]
  device='GPU'; TI in [DT_INT32]; T in [DT_HALF]
  device='GPU'; TI in [DT_UINT8]; T in [DT_HALF]

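Change the dtype of the labels to tf.int64 or tf.int32. The kernel list in the error shows that OneHot only accepts integer index types (TI in DT_INT32, DT_UINT8, DT_INT64), so labels of any other dtype, such as float32, have no matching kernel.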
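A minimal sketch of that fix, assuming the labels reach the function as a float tensor of class indices. The shapes and random inputs are made up for illustration. Unlike the code in the question, this version also passes the lengths as rank-1 tensors in the documented `label_length`, `logit_length` order, and it relies on the default blank index of 0 for dense labels, so the example labels start at 1.

```python
import tensorflow as tf


def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    """Return the per-example CTC loss.

    predictions: float logits, shape (batch_size, timeframes, dictionary_size)
    labels: class indices, shape (batch_size, max_label_seq_length)
    label_length, logit_length: Python ints shared by every example
    """
    batch_size = predictions.shape[0]
    # OneHot has no kernel for float indices, so cast labels to an integer dtype.
    labels = tf.cast(labels, tf.int32)
    # One length per example: rank-1 int32 tensors of shape (batch_size,).
    label_length_tensor = tf.fill([batch_size], label_length)
    logit_length_tensor = tf.fill([batch_size], logit_length)
    # ctc_loss expects time-major logits by default: (timeframes, batch, classes).
    logits = tf.transpose(predictions, (1, 0, 2))
    return tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=label_length_tensor,
        logit_length=logit_length_tensor,
        logits_time_major=True,
    )


# Illustrative inputs (assumed): 4 examples, 20 timeframes, 30 classes, 5 labels each.
predictions = tf.random.normal((4, 20, 30))
int_labels = tf.random.uniform((4, 5), minval=1, maxval=30, dtype=tf.int32)
float_labels = tf.cast(int_labels, tf.float32)  # mimics labels stored as floats

loss = calculate_ctc_loss(predictions, float_labels, label_length=5, logit_length=20)
print(loss.shape)
```

If the labels already come from a `tf.data` pipeline or a NumPy array, casting them once where they are created avoids repeating the cast on every training step.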

