繁体   English   中英

如何使用TensorFlow 2.0计算CategoricalCrossentropy?

[英]How to compute CategoricalCrossentropy using TensorFlow 2.0?

我正在尝试了解TensorFlow 2.0中的CategoricalCrossentropy()损失函数。 当我使用

tf.keras.metrics.CategoricalCrossentropy(actual, pred)

我收到以下错误:

ValueError:具有多个元素的数组的真值不明确。 使用a.any()或a.all()

谁能解释如何纠正它?

docs中 ,要计算交叉熵,请使用

cce = tf.keras.metrics.CategoricalCrossentropy()
# cce.update_state(target, prediction)
cce.update_state([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

cce.result().numpy()
# 1.1769392

OTOH,如果您要计算交叉熵损失 ,请改用tf.keras.losses.CategoricalCrossentropy

cce = tf.keras.metrics.CategoricalCrossentropy()
cce([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]).numpy()
# 1.1769392

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM