![](/img/trans.png)
[英]CIFAR-10 TensorFlow: InvalidArgumentError (see above for traceback): logits and labels must be broadcastable
[英]Train a NN for a subclass of cifar in Tensorflow : map labels
我想你可能想做這樣的事情(這將在你將數據/標簽提供給 .network 之前作為預處理步驟完成):
# dummy array/labels
arr = [1, 3, 6, 8, 10, 4, 7, 15, 25, 19]
print('original array =', arr)
# create mapping (range(20) in your case)
mapping = dict(zip(set(arr), range(10)))
print('mapping =', mapping)
# apply the mapping
new_arr = list(map(lambda x: mapping[x], arr))
print('new array =', new_arr)
# >> output:
# original array = [1, 3, 6, 8, 10, 4, 7, 15, 25, 19]
# mapping = {1: 0, 3: 1, 4: 2, 6: 3, 7: 4, 8: 5, 10: 6, 15: 7, 19: 8, 25: 9}
# new array = [0, 1, 3, 5, 6, 2, 4, 7, 9, 8]
所以基本上你的原始標簽(這是 len(set(labels)) = 20 但值 > 20,如果我理解正確的話)被映射到最小的可能值,以便它可以處理你的損失 function。它可能如果您需要將 map 標簽恢復為原始值,請保留映射以備后用。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.