How can I solve the “Could not find valid device” issue with CTC loss in TensorFlow 2?
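I want to compute the CTC loss for an OCR problem, but whenever I run the code, it fails with:
NotFoundError: Could not find valid device for node. Node: {{node OneHot}} All kernels registered for op OneHot:
I am using TensorFlow 2 on Google Colab.
Here are the key parts of the code: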
import tensorflow as tf

def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    # shape of predictions (batch_size, timeframes, dictionary_size) --> (64, 20, 30)
    # shape of labels (batch_size, max_label_seq_length) --> (64, 20)
    # label_length is a scalar, here 20; the same holds for logit_length
    # tf.nn.ctc_loss expects both lengths as 1-D tensors of shape (batch_size,)
    label_length_tensor = tf.constant(label_length, shape=(labels.shape[0],))
    logit_length_tensor = tf.constant(logit_length, shape=(labels.shape[0],))
    # time-major layout expected by default: (timeframes, batch_size, dictionary_size)
    logits = tf.transpose(predictions, (1, 0, 2))
    loss = tf.nn.ctc_loss(labels, logits,
                          label_length=label_length_tensor,
                          logit_length=logit_length_tensor)
    return loss
I expected to get the computed loss, but instead I get the following exception:
NotFoundError Traceback (most recent call last)
<ipython-input-236-6ddd6739d50d> in <module>()
14 #print(batch)
15 #print(target)
---> 16 batch_loss, t_loss, oc_loss, accuracy = train_step(img_tensor, target)
17 total_loss += t_loss
18 one_char_loss += oc_loss
12 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
NotFoundError: Could not find valid device for node.
Node:{{node OneHot}}
All kernels registered for op OneHot :
device='XLA_CPU'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
device='XLA_CPU_JIT'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
device='XLA_GPU_JIT'; TI in [DT_INT32, DT_UINT8, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
device='GPU'; TI in [DT_INT64]; T in [DT_INT64]
device='GPU'; TI in [DT_INT32]; T in [DT_INT64]
device='GPU'; TI in [DT_UINT8]; T in [DT_INT64]
device='GPU'; TI in [DT_INT64]; T in [DT_INT32]
device='GPU'; TI in [DT_INT32]; T in [DT_INT32]
device='GPU'; TI in [DT_UINT8]; T in [DT_INT32]
device='GPU'; TI in [DT_INT64]; T in [DT_BOOL]
device='GPU'; TI in [DT_INT32]; T in [DT_BOOL]
device='GPU'; TI in [DT_UINT8]; T in [DT_BOOL]
device='GPU'; TI in [DT_INT64]; T in [DT_DOUBLE]
device='GPU'; TI in [DT_INT32]; T in [DT_DOUBLE]
device='GPU'; TI in [DT_UINT8]; T in [DT_DOUBLE]
device='GPU'; TI in [DT_INT64]; T in [DT_FLOAT]
device='GPU'; TI in [DT_INT32]; T in [DT_FLOAT]
device='GPU'; TI in [DT_UINT8]; T in [DT_FLOAT]
device='GPU'; TI in [DT_INT64]; T in [DT_HALF]
device='GPU'; TI in [DT_INT32]; T in [DT_HALF]
device='GPU'; TI in [DT_UINT8]; T in [DT_HALF]
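Change the type of the labels to tf.int64 or tf.int32. Every kernel registered for OneHot in the list above accepts indices (TI) only of type DT_INT32, DT_INT64 or DT_UINT8. The OneHot node in the traceback is apparently applied to the dense labels inside tf.nn.ctc_loss, so labels of any other dtype leave it with no matching kernel on any device. Casting before the call, for example labels = tf.cast(labels, tf.int32), is enough.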
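A minimal sketch of the question's function with that cast applied. It assumes predictions are (batch_size, timeframes, dictionary_size) and labels are (batch_size, max_label_seq_length), as the transpose in the question implies. It also builds the length tensors with shape [batch_size], which is what the tf.nn.ctc_loss documentation specifies, using tf.fill and tf.shape so the batch size does not have to be known statically. The dummy inputs under __main__ are illustrative only.

```python
import tensorflow as tf


def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    """CTC loss for batch-major dense inputs.

    predictions: float logits, shape (batch_size, timeframes, dictionary_size)
    labels: dense label indices, shape (batch_size, max_label_seq_length)
    label_length, logit_length: Python ints shared by every sample in the batch
    """
    # Every OneHot kernel in the error log takes int32/int64/uint8 indices only.
    labels = tf.cast(labels, tf.int32)

    # tf.nn.ctc_loss documents both length arguments as shape [batch_size].
    batch_size = tf.shape(labels)[0]
    label_length_tensor = tf.fill([batch_size], label_length)
    logit_length_tensor = tf.fill([batch_size], logit_length)

    # logits_time_major defaults to True: (timeframes, batch_size, classes).
    logits = tf.transpose(predictions, (1, 0, 2))
    return tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=label_length_tensor,
        logit_length=logit_length_tensor,
    )


if __name__ == "__main__":
    # Dummy batch: 64 samples, 20 time steps, 30 classes, 5 labels per sample.
    predictions = tf.random.normal((64, 20, 30))
    # Labels start at 1 because tf.nn.ctc_loss uses class 0 as the blank by default.
    labels = tf.random.uniform((64, 5), minval=1, maxval=30, dtype=tf.int32)
    # Passing float labels on purpose: the cast inside the function handles them.
    loss = calculate_ctc_loss(predictions, tf.cast(labels, tf.float32), 5, 20)
    print(loss.shape)
```

The dummy labels are kept shorter than the 20 time steps because CTC needs an extra frame for every pair of repeated consecutive labels; when label_length equals logit_length, any repeat leaves no valid alignment and the loss can come out as infinite.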
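Statement: The technical posts on this site are published under the CC BY-SA 4.0 license. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.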