繁体   English   中英

如何创建自定义 keras 优化器?

[英]How can I create a custom keras optimizer?

我正在比较 SVRG、SAG 和其他优化器在深度学习最小化方面的性能。

如何使用 keras 实现自定义优化器,我尝试在此处查看 SGD keras 实现源代码,但找不到tf.raw_ops.ResourceApplyGradientDescent的源代码,这使得很难为另一个优化器重现。

要自定义优化器:

  • 扩展tf.keras.optimizers.Optimizer
  • 覆盖_create_slots :这用于为每个可训练变量创建优化器变量。 如果您需要为优化器添加动量,这将很有用。
  • 覆盖_resource_apply_dense_resource_apply_sparse以执行优化器的实际更新和方程。
  • get_config (可选):存储您传递给优化器的参数,以便您可以克隆,或者之后保存您的 model。

这是 SGD 的一个简单示例,动量取自此处

class MyMomentumOptimizer(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.001, momentum=0.9, name="MyMomentumOptimizer", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
        self._set_hyper("decay", self._initial_decay) # 
        self._set_hyper("momentum", momentum)
    
    def _create_slots(self, var_list):
        """For each model variable, create the optimizer variable associated with it.
        TensorFlow calls these optimizer variables "slots".
        For momentum optimization, we need one momentum slot per model variable.
        """
        for var in var_list:
            self.add_slot(var, "momentum")

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Update the slots and perform one optimization step for one model variable
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype) # handle learning rate decay
        momentum_var = self.get_slot(var, "momentum")
        momentum_hyper = self._get_hyper("momentum", var_dtype)
        momentum_var.assign(momentum_var * momentum_hyper - (1. - momentum_hyper)* grad)
        var.assign_add(momentum_var * lr_t)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        }

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM