
How to Update Trainable Variables on My Custom Optimizer Using Tensorflow 2

I am new to convolutional neural networks.

I have been implementing AlexNet, referring to the paper "ImageNet Classification with Deep Convolutional Neural Networks", using Tensorflow 2.3 in an Anaconda environment. However, I am stuck on implementing a custom optimizer.

My problem: I have to modify the optimizer according to the AlexNet paper. I could not find a reference on how to update variables with Tensorflow 2, even though I have searched. Everything I found uses "tf.assign()", which Tensorflow 2 does not support, and even if I were to use that function, I would also worry about compatibility between V1 and V2.

I only know that I have to customize the "_resource_apply_dense" function to fit my update rule. I loaded some hyperparameters there, but I do not know how to update them.

(tf.Variable() can be used like a Python variable, so I guess it is the same as tf.Variable(....?))

Thanks in advance to all readers ^_^

Here is the code.

- update rule in AlexNet
  v_(i+1) = momentum * v_i - 0.0005 * learning_rate * w_i - learning_rate * gradient
  w_(i+1) = w_i + v_(i+1)

# where
# w_i = weight
# v_i = velocity
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export
import tensorflow as tf

class AlexSGD(optimizer_v2.OptimizerV2):

    _HAS_AGGREGATE_GRAD = True

    def __init__(self,
                learning_rate=0.01,
                weight_decay=0.0005,
                momentum=0.9,
                name="AlexSGD",
                **kwargs):
        super(AlexSGD, self).__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        
        self._is_first = True
        self._set_hyper("vx", 0)
        self._set_hyper("pg", 0)
        self._set_hyper("pv", 0)
        self._weight_decay = False
        if isinstance(weight_decay, ops.Tensor) or callable(weight_decay) or weight_decay > 0:
            self._weight_decay = True
        if isinstance(weight_decay, (int, float)) and (weight_decay < 0 or weight_decay > 1):
            raise ValueError("`weight_decay` must be between [0, 1].")
        self._set_hyper("weight_decay", weight_decay)

        self._momentum = False
        if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
            self._momentum = True
        if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
            raise ValueError("`momentum` must be between [0, 1].")
        self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
        if self._momentum:
            for var in var_list:
                self.add_slot(var, "momentum")
        if self._weight_decay:
            for var in var_list:
                self.add_slot(var, "weight_decay")
        for var in var_list:
            self.add_slot(var, 'pv')  # previous variable, i.e. weight or bias
        for var in var_list:
            self.add_slot(var, 'pg')  # previous gradient
        for var in var_list:
            self.add_slot(var, 'vx')  # update velocity

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(AlexSGD, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
            self._get_hyper("momentum", var_dtype))
        apply_state[(var_device, var_dtype)]["weight_decay"] = array_ops.identity(
            self._get_hyper("weight_decay", var_dtype))
        apply_state[(var_device, var_dtype)]["vx"] = array_ops.identity(
            self._get_hyper("vx", var_dtype))
        apply_state[(var_device, var_dtype)]["pv"] = array_ops.identity(
            self._get_hyper("pv", var_dtype))
        apply_state[(var_device, var_dtype)]["pg"] = array_ops.identity(
            self._get_hyper("pg", var_dtype))

    # main function
    @tf.function
    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))
        momentum_var = self.get_slot(var, "momentum")
        weight_decay_var = self.get_slot(var, "weight_decay")
        vx_var = self.get_slot(var, "vx")
        pv_var = self.get_slot(var, "pv")
        pg_var = self.get_slot(var, "pg")
        lr_t = self._decayed_lr(var_dtype)

        # update rule in AlexNet
        # v_(i+1) = momentum * v_i - 0.0005 * lr * w_i - lr * grad
        # w_(i+1) = w_i + v_(i+1)
        # where
        # w_i = var
        # vx, v_i = velocity (Feel like I need to set this variable as a slot) 
        # lr = learning_rate
        # grad = gradient
        
        # I'm confused about why the pv, pg variables are declared...
        # are they replaced by var & grad?  (pv, pg are taken from a blog post)
        # pv = previous var
        # pg = previous gradient

        if self._is_first:
            self._is_first = False
            vx_var = grad
            new_var = var + vx_var
        else:
            vx_var = momentum_var * vx_var - weight_decay_var * lr_t * pv_var - lr_t * pg_var
            new_var = var + vx_var

        print("grad:",grad)
        print("var:",var)
        print("vx_var:",vx_var)
        print("new_var:",new_var)
        print("pv_var:",pv_var)
        print("pg_var:",pg_var)
        
        # TODO: I got stuck on how to update the variables because the tf.assign()
        #       function is deprecated in Tensorflow V2
        pg_var = grad
        pv_var = var

        if var == new_var:
            var = new_var
        
        # TODO: In order to update variables, I can't find the equivalent 
        #        "tf.assign" method in TF V2
        # pg_var.assign(grad)
        # vx_var.assign(vx_var)
        # var.assign(new_var)
        
        
        
        """
        # TODO: I referred to the code below from the Tensorflow official documentation,
        #       and I realized the training_ops module is a C++ library, so I thought I
        #       can't modify it (because I need to modify the velocity update function)

        # return training_ops.resource_apply_keras_momentum(
        #     var.handle,
        #     momentum_var.handle,
        #     coefficients["lr_t"],
        #     grad,
        #     coefficients["momentum"],
        #     use_locking=self._use_locking,
        #     use_nesterov=self.nesterov)


        # if self._momentum :
        #  momentum_var = self.get_slot(var, "momentum")
        #   return training_ops.resource_apply_keras_momentum(
        #     var.handle,
        #     momentum_var.handle,
        #     coefficients["lr_t"],
        #     grad,
        #     coefficients["momentum"],
        #     use_locking=self._use_locking,
        #     use_nesterov=self.nesterov)
        # else:
        #   return training_ops.resource_apply_gradient_descent(
        #       var.handle, coefficients["lr_t"], grad,
        #       use_locking=self._use_locking)
        """


    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError

    def get_config(self):
        config = super(AlexSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "weight_decay": self._serialize_hyperparameter("weight_decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

Reference: https://www.kdnuggets.com/2018/01/custom-optimizer-tensorflow.html

Here is another reference, which also uses tf.assign():

from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer
import tensorflow as tf

class AlexOptimizer(optimizer.Optimizer):
    def __init__(self, learning_rate="learning_rate", alpha="alpha", beta="beta",
                 weight_decay="weight_decay", use_locking=False, name="AlexOptimizer"):
        super(AlexOptimizer, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._wd = weight_decay
        self._alpha = alpha
        self._beta = beta
        # Tensor versions of the constructor arguments, created in _prepare().
        self._lr_t = None
        self._wd_t = None
        self._alpha_t = None
        self._beta_t = None

    def _prepare(self):
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._wd_t = ops.convert_to_tensor(self._wd, name="weight_decay")
        self._alpha_t = ops.convert_to_tensor(self._alpha, name="alpha_t")
        self._beta_t = ops.convert_to_tensor(self._beta, name="beta_t")

    def _create_slots(self, var_list):
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)

    def _apply_dense(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        wd_t = math_ops.cast(self._wd_t, var.dtype.base_dtype)
        alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
        beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

        eps = 1e-7 #cap for moving average
        m = self.get_slot(var, "m")
        m_t = m.assign(tf.maximum(beta_t * m + eps, tf.abs(grad)))

        var_update = state_ops.assign_sub(var, lr_t*grad*tf.exp( tf.log(alpha_t)*tf.sign(grad)*tf.sign(m_t))) 
        # Update 'ref' by subtracting value
        # Create an op that groups multiple operations.
        # When this op finishes, all ops in input have finished
        return control_flow_ops.group(*[var_update, m_t])
    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Sparse gradient updates are not supported.")

I want to modify the code using OptimizerV2. Which variables should I update?

(p.s.) Is the usage of "@tf.function" above "def _resource_apply_dense()" correct?

On another note, my model keeps reshuffling while training ㅠ_ㅠ (the data keeps getting shuffled by the dataset preprocessing step (tf.data.Dataset.shuffle()), even though the call is not inside a while loop... (sorry for not posting that code... so never mind...)
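As a side note on that p.s. about shuffling: assuming the preprocessing uses the standard tf.data.Dataset.shuffle, it reshuffles every time the dataset is iterated (i.e. every epoch) by default, and passing reshuffle_each_iteration=False keeps a single fixed order. A minimal sketch; the data and buffer size below are just placeholders:

import tensorflow as tf

ds = tf.data.Dataset.range(10)  # placeholder for the real preprocessed dataset

# Default: a new order every time the dataset is iterated (every epoch).
reshuffled = ds.shuffle(buffer_size=10, seed=42)

# Fixed: shuffle once and reuse the same order for every epoch.
fixed = ds.shuffle(buffer_size=10, seed=42, reshuffle_each_iteration=False)

for epoch in range(2):
    print("reshuffled:", list(reshuffled.as_numpy_iterator()))
    print("fixed:     ", list(fixed.as_numpy_iterator()))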

You don't need the v1 tf.assign; in v2, assign is a method of the Variable class.
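For example, a minimal sketch of updating a variable with the v2 instance methods (the values are just placeholders):

import tensorflow as tf

v = tf.Variable(1.0)
v.assign(2.0)      # replaces the value, like the old tf.assign(v, 2.0)
v.assign_add(0.5)  # in-place add, like the old tf.assign_add(v, 0.5)
print(v.numpy())   # 2.5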

Your second example does not use tf.assign:

m_t = m.assign(tf.maximum(beta_t * m + eps, tf.abs(grad)))

In v2 it just calls the same class method from the v2 API.
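So inside _resource_apply_dense you can update the slot variables and the model variable with their own assign / assign_add methods and return the grouped ops. Below is a minimal sketch of the AlexNet update rule written that way; it is an illustration under those assumptions, not the one true implementation: it keeps only a single velocity slot (the pv/pg slots are unnecessary because var and grad already hold the current values), and the hyperparameter validation from the question is omitted for brevity.

import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

class AlexSGDSketch(optimizer_v2.OptimizerV2):
    # Sketch of: v_(i+1) = momentum * v_i - weight_decay * lr * w_i - lr * grad
    #            w_(i+1) = w_i + v_(i+1)

    def __init__(self, learning_rate=0.01, weight_decay=0.0005, momentum=0.9,
                 name="AlexSGDSketch", **kwargs):
        super(AlexSGDSketch, self).__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
        # One velocity slot per trainable variable, initialized to zeros.
        for var in var_list:
            self.add_slot(var, "velocity")

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        momentum = self._get_hyper("momentum", var_dtype)
        weight_decay = self._get_hyper("weight_decay", var_dtype)
        velocity = self.get_slot(var, "velocity")

        # v_(i+1) = momentum * v_i - weight_decay * lr * w_i - lr * grad
        new_velocity = momentum * velocity - weight_decay * lr_t * var - lr_t * grad

        # Variable.assign / assign_add replace the old tf.assign / tf.assign_add.
        velocity_update = velocity.assign(new_velocity, use_locking=self._use_locking)
        var_update = var.assign_add(new_velocity, use_locking=self._use_locking)
        return tf.group(velocity_update, var_update)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError

    def get_config(self):
        config = super(AlexSGDSketch, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "weight_decay": self._serialize_hyperparameter("weight_decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

Used with model.compile(optimizer=AlexSGDSketch(), ...), this exercises the dense path; sparse gradients (e.g. from embedding layers) would still need a real _resource_apply_sparse. Note also that the built-in v2 optimizers do not put @tf.function on _resource_apply_dense, so the decorator in the question is not required.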

