
How to prevent Keras from computing metrics during training

I am using Tensorflow/Keras 2.4.1 and I have an (unsupervised) custom metric that takes several of my model inputs as parameters, e.g.:

model = build_model() # returns a tf.keras.Model object
my_metric = custom_metric(model.output, model.input[0], model.input[1])
# note: for a symbolic tensor, add_metric requires a name (and, depending
# on the TF version, aggregation='mean')
model.add_metric(my_metric, name='my_metric', aggregation='mean')
[...]
model.fit([...]) # training with fit

However, it turns out that custom_metric is very expensive, so I would like it to be computed during validation only. I found this answer, but I can hardly see how to adapt the solution to a metric that uses several of my model inputs as parameters, since the update_state method does not seem flexible.
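For concreteness, the standard stateful-metric API I would have to adapt looks roughly like this (a minimal sketch; the class name and the placeholder computation are mine, just to illustrate the signature):

import tensorflow as tf

class ExpensiveMetric(tf.keras.metrics.Metric):
    """Sketch of a stateful metric following the standard Keras API."""

    def __init__(self, name='expensive_metric', **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name='total', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # update_state only receives (y_true, y_pred, sample_weight):
        # there is no obvious way to also get model.input[0] and
        # model.input[1] in here.
        value = tf.reduce_mean(tf.square(y_pred))  # placeholder computation
        self.total.assign_add(value)
        self.count.assign_add(1.0)

    def result(self):
        return self.total / self.count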

In my context, is there any way to avoid computing my metric during training, other than writing my own training loop? Also, I am very surprised that we cannot natively tell Tensorflow that some metrics should only be computed at validation time: is there a reason for that?

Moreover, since the model is trained to optimize the loss, and the training dataset should not be used to evaluate a model, I do not even understand why Tensorflow computes metrics during training by default.

I think that the simplest solution to compute a metric only at validation time is to use a custom callback.

Here we define our dummy callback:

import numpy as np
import tensorflow as tf


class MyCustomMetricCallback(tf.keras.callbacks.Callback):

    def __init__(self, train=None, validation=None):
        super(MyCustomMetricCallback, self).__init__()
        self.train = train
        self.validation = validation

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        mse = tf.keras.losses.mean_squared_error

        if self.train:
            X_train, y_train = self.train[0], self.train[1]
            y_pred = self.model.predict(X_train)
            # mse returns one value per sample; average it to log a scalar
            score = np.mean(mse(y_train, y_pred))
            logs['my_metric_train'] = np.round(score, 5)

        if self.validation:
            X_valid, y_valid = self.validation[0], self.validation[1]
            y_pred = self.model.predict(X_valid)
            val_score = np.mean(mse(y_valid, y_pred))
            logs['my_metric_val'] = np.round(val_score, 5)

Given this dummy model:

from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model


def build_model():

  inp1 = Input((5,))
  inp2 = Input((5,))
  out = Concatenate()([inp1, inp2])
  out = Dense(1)(out)

  model = Model([inp1, inp2], out)
  model.compile(loss='mse', optimizer='adam')

  return model

And this data:

X_train1 = np.random.uniform(0,1, (100,5))
X_train2 = np.random.uniform(0,1, (100,5))
y_train = np.random.uniform(0,1, (100,1))

X_val1 = np.random.uniform(0,1, (100,5))
X_val2 = np.random.uniform(0,1, (100,5))
y_val = np.random.uniform(0,1, (100,1))

you can use the custom callback to compute the metric both on train and validation:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train), validation=([X_val1, X_val2],y_val))])

Only on validation:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2],y_val))])

Only on train:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train))])

Keep in mind that the callback evaluates the metric on the data in one shot, just like any metric/loss computed by default by keras on validation_data.

Here is the running code.
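Since the whole point is that the metric is expensive, a variant worth considering (a hypothetical extension of the callback above, not part of the original answer) is to evaluate it only every few epochs:

class MyPeriodicMetricCallback(MyCustomMetricCallback):

    def __init__(self, every=5, **kwargs):
        super().__init__(**kwargs)
        self.every = every  # evaluate the expensive metric every `every` epochs

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every == 0:
            super().on_epoch_end(epoch, logs)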

I was able to use learning_phase, but only in symbolic tensor (graph) mode:

So, first we need to disable eager mode (this must be done right after importing tensorflow):

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

Then you can create the metric using a symbolic if (backend.switch):

from tensorflow.keras import backend as K

def metric_graph(in1, in2, out):
    actual_metric = out * (in1 + in2)
    return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric)

The add_metric method will ask for a name and an aggregation method, which you can set to "mean".

So, here's an example:

import numpy
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

x1 = numpy.ones((5,3))
x2 = numpy.ones((5,3))
y = 3*numpy.ones((5,1))

vx1 = numpy.ones((5,3))
vx2 = numpy.ones((5,3))
vy = 3*numpy.ones((5,1))

# eager version, shown for comparison only: a Python `if` on a symbolic
# tensor cannot be compiled, which is why eager execution is disabled
def metric_eager(in1, in2, out):
    if (K.learning_phase()):
        return 0
    else:
        return out * (in1 + in2)

# graph version: a symbolic if that returns zeros during training
def metric_graph(in1, in2, out):
    actual_metric = out * (in1 + in2)
    return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric)

ins1 = Input((3,))
ins2 = Input((3,))
outs = Concatenate()([ins1, ins2])
outs = Dense(1)(outs)
model = Model([ins1, ins2], outs)
model.add_metric(metric_graph(ins1, ins2, outs), name='my_metric', aggregation='mean')
model.compile(loss='mse', optimizer='adam')

model.fit([x1, x2], y, validation_data=([vx1, vx2], vy), epochs=3)
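For clarity, in graph mode K.learning_phase() is a symbolic scalar that Keras feeds as true during the training pass of fit and as false during validation and predict, which is why K.switch picks the zeros branch only while training. A quick way to sanity-check this (a standalone snippet under the same disabled-eager assumption):

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
from tensorflow.keras import backend as K

# K.switch returns its second argument when the condition is true (training)
toy = K.switch(K.learning_phase(), tf.zeros((1,)), tf.ones((1,)))

sess = tf.compat.v1.keras.backend.get_session()
print(sess.run(toy, feed_dict={K.learning_phase(): True}))   # [0.] -> training
print(sess.run(toy, feed_dict={K.learning_phase(): False}))  # [1.] -> validation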

Since the metrics are run inside the train_step function of keras.Model, filtering out train-disabled metrics without changing the API requires subclassing keras.Model.

We define a simple metric wrapper:

from tensorflow.keras.metrics import Metric


class TrainDisabledMetric(Metric):

  def __init__(self, metric: Metric):
    super().__init__(name=metric.name)
    self._metric = metric

  def update_state(self, *args, **kwargs):
    return self._metric.update_state(*args, **kwargs)

  def reset_state(self):
    return self._metric.reset_state()

  def result(self):
    return self._metric.result()

and subclass keras.Model to filter out those metrics during training:

import tensorflow as tf
from tensorflow import keras
# note: these helpers live in private modules whose exact paths can change
# between TF/Keras versions; the paths below match recent TF 2.x releases
from keras.engine import compile_utils, data_adapter
from tensorflow.python.util.nest import (flatten_up_to,
                                         get_traverse_shallow_structure)


class CustomModel(keras.Model):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

  def compile(self, optimizer='rmsprop', loss=None, metrics=None,
              loss_weights=None, weighted_metrics=None, run_eagerly=None,
              steps_per_execution=None, jit_compile=None, **kwargs):

    from_serialized = kwargs.get('from_serialized', False)

    super().compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights,
                    weighted_metrics=weighted_metrics, run_eagerly=run_eagerly,
                    steps_per_execution=steps_per_execution,
                    jit_compile=jit_compile, **kwargs)

    self.on_train_compiled_metrics = self.compiled_metrics

    if metrics is not None:

      def get_on_train_traverse_tree(structure):
        flat = tf.nest.flatten(structure)
        on_train = [not isinstance(e, TrainDisabledMetric) for e in flat]
        full_tree = tf.nest.pack_sequence_as(structure, on_train)
        return get_traverse_shallow_structure(lambda s: any(tf.nest.flatten(s)),
                                              full_tree)

      on_train_sub_tree = get_on_train_traverse_tree(metrics)
      flat_on_train = flatten_up_to(on_train_sub_tree, metrics)

      def clean_tree(tree):
        if isinstance(tree, list):
          _list = []
          for t in tree:
            r = clean_tree(t)
            if r:
              _list.append(r)
          return _list

        elif isinstance(tree, dict):
          _tree = {}
          for k, v in tree.items():
            r = clean_tree(v)
            if r:
              _tree[k] = r
          return _tree
        else:
          return tree

      pruned_on_train_sub_tree = clean_tree(on_train_sub_tree)
      pruned_flat_on_train = [m for keep, m in
                              zip(tf.nest.flatten(on_train_sub_tree),
                                  flat_on_train) if keep]

      on_train_metrics = tf.nest.pack_sequence_as(pruned_on_train_sub_tree,
                                                  pruned_flat_on_train)

      self.on_train_compiled_metrics = compile_utils.MetricsContainer(
        on_train_metrics, weighted_metrics=None, output_names=self.output_names,
        from_serialized=from_serialized)

  def train_step(self, data):
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    # Run forward pass.
    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    self._validate_target_and_loss(y, loss)
    # Run backwards pass.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    return self.compute_metrics(x, y, y_pred, sample_weight, training=True)

  def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
    del x  # The default implementation does not use `x`.

    if training:
      self.on_train_compiled_metrics.update_state(y, y_pred, sample_weight)
      metrics = self.on_train_metrics
    else:
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
      metrics = self.metrics
    # Collect metrics to return
    return_metrics = {}
    for metric in metrics:
      result = metric.result()
      if isinstance(result, dict):
        return_metrics.update(result)
      else:
        return_metrics[metric.name] = result
    return return_metrics

  @property
  def on_train_metrics(self):
    metrics = []
    if self._is_compiled:
      # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
      # so that attr names are not load-bearing.
      if self.compiled_loss is not None:
        metrics += self.compiled_loss.metrics
      if self.on_train_compiled_metrics is not None:
        metrics += self.on_train_compiled_metrics.metrics

    for l in self._flatten_layers():
      metrics.extend(l._metrics)  # pylint: disable=protected-access
    return metrics

Now, given a keras model, we can wrap it and compile it with the train-disabled metrics:

model: keras.Model = ...
custom_model = CustomModel(inputs=model.input, outputs=model.output)

train_enabled_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

# wrap train disabled metrics with `TrainDisabledMetric`:
train_disabled_metrics = [
  TrainDisabledMetric(tf.keras.metrics.SparseCategoricalCrossentropy())]

metrics = train_enabled_metrics + train_disabled_metrics

custom_model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(
                       from_logits=True), metrics=metrics, )

custom_model.fit(ds_train, epochs=6, validation_data=ds_test, )
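ds_train and ds_test are left unspecified above; the step count in the log below (469 batches per epoch) is consistent with, e.g., MNIST batched at 128, so a hypothetical setup could be:

import tensorflow as tf

# hypothetical data pipeline, just to make the example concrete
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(128)
ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)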

The SparseCategoricalCrossentropy metric is computed during validation only:

Epoch 1/6
469/469 [==============================] - 2s 2ms/step - loss: 0.3522 - sparse_categorical_accuracy: 0.8366 - val_loss: 0.1978 - val_sparse_categorical_accuracy: 0.9086 - val_sparse_categorical_crossentropy: 1.3197
Epoch 2/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1631 - sparse_categorical_accuracy: 0.9526 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9587 - val_sparse_categorical_crossentropy: 1.1910
Epoch 3/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1178 - sparse_categorical_accuracy: 0.9654 - val_loss: 0.1139 - val_sparse_categorical_accuracy: 0.9661 - val_sparse_categorical_crossentropy: 1.1369
Epoch 4/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9735 - val_loss: 0.0981 - val_sparse_categorical_accuracy: 0.9715 - val_sparse_categorical_crossentropy: 1.0434
Epoch 5/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0735 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0913 - val_sparse_categorical_accuracy: 0.9721 - val_sparse_categorical_crossentropy: 0.9862
Epoch 6/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0824 - val_sparse_categorical_accuracy: 0.9761 - val_sparse_categorical_crossentropy: 1.0024

