简体   繁体   English

在 Tensorflow 2.2 中使用具有 SparseCategoricalCrossEntropy 损失的 tf.metrics.MeanIoU() 时出现尺寸不匹配错误

[英]Dimensions mismatch error when using tf.metrics.MeanIoU() with SparseCategoricalCrossEntropy loss in Tensorflow 2.2

Refer to # https://github.com/tensorflow/tensorflow/issues/32875参考# https://github.com/tensorflow/tensorflow/issues/32875

The suggested fix was to:建议的修复是:

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    @tf.function
    def __call__(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1) # this is the fix
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)

It worked for TF2.1, but broke again in TF2.2.它适用于 TF2.1,但在 TF2.2 中再次中断。 Is there a way to pass y_pred = tf.argmax(y_pred, axis=-1) as y_pred to this metric other than subclassing?除了子类化之外,有没有办法将y_pred = tf.argmax(y_pred, axis=-1)作为y_pred给这个指标?

This fixes the issue:这解决了这个问题:

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
  def __init__(self,
               y_true=None,
               y_pred=None,
               num_classes=None,
               name=None,
               dtype=None):
    super(UpdatedMeanIoU, self).__init__(num_classes = num_classes,name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_pred = tf.math.argmax(y_pred, axis=-1)
    return super().update_state(y_true, y_pred, sample_weight)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 tf.keras.metrics.MeanIoU 结果没有改善 - tf.keras.metrics.MeanIoU outcome is not improving 使用RNN在Tensorflow中采样的Softmax损失-尺寸不匹配问题 - Sampled Softmax Loss in Tensorflow with RNN - Dimensions mismatch problem 使用自定义损失函数时出现 Tensorflow 错误 - Tensorflow error when using Custom Loss function 在tf.metrics.mean_absolute_error中除以0时,张量流的奇怪行为 - Strange behavior of tensorflow when dividing by 0 in tf.metrics.mean_absolute_error 使用 TF Estimator 时 Tensorflow 分布式训练的损失和学习率缩放策略 - Loss and learning rate scaling strategies for Tensorflow distributed training when using TF Estimator Tensorflow Estimator:使用tf.feature_column.embedding_column作为类别变量列表时的损失不会减少 - Tensorflow Estimator: loss not decreasing when using tf.feature_column.embedding_column for a list of categorical variables SparseCategoricalCrossentropy 形状不匹配 - SparseCategoricalCrossentropy Shape Mismatch 将 tensorflow 导入为 tf 时导入 Tensorflow 时出错 - Error importing Tensorflow when import tensorflow as tf 使用 yolo4.cfg 进行 tensorflow 2.2 训练中的形状不匹配问题 - Shape mismatch problem in tensorflow 2.2 training using yolo4.cfg tf.keras.losses.SparseCategoricalCrossentropy() 与“sparse_categorical_crossentropy”作为损失的区别 - difference between tf.keras.losses.SparseCategoricalCrossentropy() vs “sparse_categorical_crossentropy” as loss
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM