Advice on how to create a custom tf.keras optimizer (optimizer_v2)

How can I create a custom keras optimizer?
I am comparing the performance of SVRG, SAG, and other optimizers for deep learning minimization. How can I implement a custom optimizer with Keras? I tried looking at the Keras SGD implementation source code here, but I could not find the source for `tf.raw_ops.ResourceApplyGradientDescent`, which makes it hard to reproduce the pattern for another optimizer.
To customize an optimizer, subclass `tf.keras.optimizers.Optimizer` and implement:

- `_create_slots`: creates the optimizer variables associated with each trainable variable. This is useful if your optimizer needs per-variable state, such as momentum.
- `_resource_apply_dense` or `_resource_apply_sparse`: performs the actual update, i.e. the optimizer's update equations.
- `get_config` (optional): stores the parameters you passed to the optimizer so that it can be cloned, or your model saved, later.

Here is a simple example of SGD with momentum, taken from here:
```python
import tensorflow as tf
from tensorflow import keras

class MyMomentumOptimizer(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.001, momentum=0.9, name="MyMomentumOptimizer", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))  # handle lr=learning_rate
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
        """For each model variable, create the optimizer variable associated with it.
        TensorFlow calls these optimizer variables "slots".
        For momentum optimization, we need one momentum slot per model variable.
        """
        for var in var_list:
            self.add_slot(var, "momentum")

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Update the slots and perform one optimization step for one model variable."""
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)  # handle learning rate decay
        momentum_var = self.get_slot(var, "momentum")
        momentum_hyper = self._get_hyper("momentum", var_dtype)
        momentum_var.assign(momentum_var * momentum_hyper - (1. - momentum_hyper) * grad)
        var.assign_add(momentum_var * lr_t)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        }
```
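As a quick sanity check of the update rule in `_resource_apply_dense` (m ← β·m − (1−β)·g, then w ← w + lr·m), the same arithmetic can be replayed in plain Python on a one-variable quadratic. This is a standalone sketch that does not require TensorFlow; the learning rate, momentum, and step count are illustrative choices, not values from the original post:

```python
# Replays the momentum update from _resource_apply_dense above on f(w) = w^2,
# whose gradient is 2*w. The same two-line rule is applied each step:
#   m <- beta * m - (1 - beta) * g
#   w <- w + lr * m
lr, beta = 0.1, 0.9   # illustrative hyperparameters
w, m = 1.0, 0.0       # start away from the minimum, zero momentum

def step(w, m, grad):
    m = beta * m - (1.0 - beta) * grad
    return w + lr * m, m

for _ in range(200):
    w, m = step(w, m, 2.0 * w)  # gradient of w^2 at the current point

print(w)  # should be close to the minimum at w = 0
```

Note the sign convention: the momentum slot accumulates the *negative* gradient direction, so the variable update is `assign_add` rather than a subtraction.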