繁体   English   中英

Tensorflow 抛出 InvalidArgumentError:自定义丢失 function 后形状不兼容,在纪元结束之前

[英]Tensorflow throwing InvalidArgumentError: Incompatible shapes after custom loss function, before end of epoch

我从早上开始就一直在调试这个问题,但一直没有取得任何进展。 我正在尝试训练一个 Batch_size 大于 1 的 model。该代码适用于batch_size=1 ,但不适用于更大的数字。 如果可能,请提供帮助。

自定义 Loss function 的代码为(所有打印语句均在调试过程中添加):

def mobilenet_loss(y_true, y_pred):

    '''
    y_pred will be an tensor of (None, 21, 21, 3) where the last axis corresponds to the classes from class_list

    y_true will be an integer 0, 1 or 2 that would indicate which class the true value belongs to

    Since y_pred is the output of a Softmax function, the sum of elements in the last dim will be 1.0
    Using y_true, we will construct a tensor y_true_array with ones in the index of y_true and zeros in others
    We will be trying to minimize sum(y_true_array - y_pred)

    '''
    
    diagnostics = True
    if diagnostics: tf.print("ytrue =", y_true, tf.shape(y_true) ," y_pred shape =", tf.shape(y_pred))
    
    one_class_shape = tf.concat([tf.shape(y_pred)[:-1], tf.constant([1])] , axis=0) 
    if diagnostics: tf.print("one_class_shape =", one_class_shape)
        
    y_true_shape = tf.shape(y_true)
    y_true_copy = tf.identity(y_true)
    tf.print("y_true_shape - ", y_true_shape)

    y_true_compare_0 = tf.zeros(y_true_shape, dtype=tf.int64, name="compare_0_zeros")
    y_true_compare_1 = tf.ones(y_true_shape, dtype=tf.int64, name="compare_1_ones")
    tf.print("y_true_compare_0 =", tf.shape(y_true_compare_0))
    tf.print("y_true_compare_1 =", tf.shape(y_true_compare_1))
    

    def concat_0(): return tf.concat([  tf.ones(one_class_shape, dtype=tf.float32, name="loss_ones_0")
                                      , tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros_0")
                                      , tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros_0_1")]
                                      , axis=-1, name="concat_0")
    def concat_1(): return tf.concat([  tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros_1")
                                      , tf.ones(one_class_shape, dtype=tf.float32, name="loss_ones_1")
                                      , tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros_1_1")]
                                      , axis=-1, name="concat_1")
    def concat_2(): return tf.concat([  tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros_2")
                                      , tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros_2_1")
                                      , tf.ones(one_class_shape, dtype=tf.float32, name="loss_ones_2")]
                                      , axis=-1, name="concat_2")
    def equals_1_check():
        return tf.cond(tf.experimental.numpy.all(tf.math.equal(tf.cast(tf.reduce_sum(y_true_copy), dtype=tf.float64)
                                , tf.divide(tf.reduce_sum(y_true_compare_1)
                                            , tf.size(y_true_copy, out_type=tf.int64)
                                            ,name="div_equals_1_check")
                                , name="equals_1_check")
                        )
                        , concat_1 
                        , concat_2 
                        , name="second_condition")

    tf.print("y_true_compare_0", y_true_compare_0)
    y_true_array = tf.cond(tf.experimental.numpy.all(
                            tf.math.equal(tf.cast(tf.reduce_sum(y_true_copy), dtype=tf.float64)
                                       , tf.divide(tf.reduce_sum(y_true_compare_0)
                                                   , tf.size(y_true_copy, out_type=tf.int64)
                                                   , name="div_equals_0_check")
                                      , name="equals_0_check"))
        , concat_0
        , equals_1_check 
        , name="first_condition")

    tf.print("div_equals_0_check shape = ", tf.divide(tf.reduce_sum(y_true_compare_0)
                                           , tf.size(y_true_copy, out_type=tf.int64), name="div_equals_0_check_1"))
    tf.print("div_equals_1_check shape = ",tf.divide(tf.reduce_sum(y_true_compare_1)
                                            , tf.size(y_true_copy, out_type=tf.int64),name="div_equals_1_check_1"))
    if diagnostics: tf.print("after y_true_array =", tf.shape(y_true_array), y_true_array.get_shape())

    error = tf.math.subtract(y_true_array, y_pred, name="error")
    if diagnostics: tf.print("error shape = ", tf.shape(error))
    
    loss = tf.math.reduce_sum(error, name='final_loss')
    if diagnostics: tf.print("loss shape = ", tf.shape(loss), loss)

    return loss

这是我用来初始化 model

model = Model(inputs = inputs, outputs = mobilenet_output, name="main_model")

model.compile(optimizer = 'adam'
              , loss =mobilenet_loss
              , metrics = [ 'accuracy']
              
              )

model.fit(x = train_generator, steps_per_epoch = 1, epochs=2, batch_size=batch_size, shuffle=False)

错误信息:

Epoch 1/2
ytrue = [[2]
 [2]] [2 1]  y_pred shape = [2 31 31 3]
one_class_shape = [2 31 31 1]
y_true_shape -  [2 1]
y_true_compare_0 = [2 1]
y_true_compare_1 = [2 1]
y_true_compare_0 [[0]
 [0]]
div_equals_0_check shape =  0
div_equals_1_check shape =  1
after y_true_array = [2 31 31 3] TensorShape([2, 31, 31, 3])
error shape =  [2 31 31 3]
loss shape =  [] -21462.2773
ytrue = [[2]
 [2]] [2 1]  y_pred shape = [2 31 31 3]
one_class_shape = [2 31 31 1]
y_true_shape -  [2 1]
y_true_compare_0 = [2 1]
y_true_compare_1 = [2 1]
y_true_compare_0 [[0]
 [0]]
div_equals_0_check shape =  0
div_equals_1_check shape =  1
after y_true_array = [2 31 31 3] TensorShape([2, 31, 31, 3])
error shape =  [2 31 31 3]
loss shape =  [] -8.79764557e-05
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-39-739652b805f0> in <module>
----> 1 model.fit(x = train_generator, steps_per_epoch = 1, epochs=2, batch_size=batch_size, shuffle=False)

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    886         # Lifting succeeded, so variables are initialized and we can run the
    887         # stateless function.
--> 888         return self._stateless_fn(*args, **kwds)
    889     else:
    890       _, _, _, filtered_flat_args = \

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2942     return graph_function._call_flat(
-> 2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2944 
   2945   @property

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(self._inference_function.call(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

~/opt/anaconda3/envs/bigthinx/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Incompatible shapes: [2,1] vs. [2,31,31]
     [[node Equal (defined at <ipython-input-39-739652b805f0>:1) ]] [Op:__inference_train_function_26226]

Function call stack:
train_function

奇怪的是,当我使batch_size = 1时,代码工作得很好

这是来自同一 model 的日志,其中 batch_size=1

Epoch 1/2
ytrue = [[2]] [1 1]  y_pred shape = [1 31 31 3]
one_class_shape = [1 31 31 1]
y_true_shape -  [1 1]
y_true_compare_0 = [1 1]
y_true_compare_1 = [1 1]
y_true_compare_0 [[0]]
div_equals_0_check shape =  0
div_equals_1_check shape =  1
after y_true_array = [1 31 31 3] TensorShape([1, 31, 31, 3])
error shape =  [1 31 31 3]
loss shape =  [] -1222.78308
ytrue = [[2]] [1 1]  y_pred shape = [1 31 31 3]
one_class_shape = [1 31 31 1]
y_true_shape -  [1 1]
y_true_compare_0 = [1 1]
y_true_compare_1 = [1 1]
y_true_compare_0 [[0]]
div_equals_0_check shape =  0
div_equals_1_check shape =  1
after y_true_array = [1 31 31 3] TensorShape([1, 31, 31, 3])
error shape =  [1 31 31 3]
loss shape =  [] -8.76188278e-06
1/1 [==============================] - 15s 15s/step - loss: -1222.7831 - mobnet_features_loss: -1222.7831 - Predictions_loss: -8.7619e-06 - mobnet_features_accuracy: 0.3777 - Predictions_accuracy: 0.3777
Epoch 2/2
ytrue = [[2]] [1 1]  y_pred shape = [1 31 31 3]
one_class_shape = [1 31 31 1]
y_true_shape -  [1 1]
y_true_compare_0 = [1 1]
y_true_compare_1 = [1 1]
y_true_compare_0 [[0]]
div_equals_0_check shape =  0
div_equals_1_check shape =  1
after y_true_array = [1 31 31 3] TensorShape([1, 31, 31, 3])
error shape =  [1 31 31 3]
loss shape =  [] -4929.66
ytrue = [[2]] [1 1]  y_pred shape = [1 31 31 3]
one_class_shape = [1 31 31 1]
y_true_shape -  [1 1]
y_true_compare_0 = [1 1]
y_true_compare_1 = [1 1]
y_true_compare_0 [[0]]
div_equals_0_check shape =  0
div_equals_1_check shape =  1
after y_true_array = [1 31 31 3] TensorShape([1, 31, 31, 3])
error shape =  [1 31 31 3]
loss shape =  [] 4.91845421e-05
1/1 [==============================] - 12s 12s/step - loss: -4929.6602 - mobnet_features_loss: -4929.6602 - Predictions_loss: 4.9185e-05 - mobnet_features_accuracy: 0.3049 - Predictions_accuracy: 0.3049

找到了我在这里发布的问题的解决方案。

  1. 我将损失重塑为形状 [batch, final_conv_width, final_conv_height, feature_classes]
  2. 我使用 tf.GradientTape 和 optimizer.apply_gradients 函数创建了一个自定义训练循环。
@tf.function
def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        tf.print("in grad function y_pred shape = ", tf.type_spec_from_value(y_pred))
        loss_value = mobilenet_loss(targets, y_pred)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

num_epochs = 2
optimizer = tf.keras.optimizers.Adam()
avg_loss = tf.keras.metrics.Mean()

for epoch in range(num_epochs):
    x, y = next(train_generator)
    # Optimize the model

    loss_value, gradients = grad(model, x, y)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    avg_loss(loss_value)
    print(" epoch %d/%d, loss=%.4f "%(epoch +1, num_epochs, avg_loss.result()))
                                  

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM