
How to apply kernel regularization in a custom layer in Keras/TensorFlow?

Consider the following custom layer code from a TensorFlow tutorial:

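import tensorflow as tf

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    # Create the kernel once the input feature dimension is known.
    # The name is passed by keyword so it cannot be mistaken for the shape.
    self.kernel = self.add_weight(name="kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)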

How do I apply a predefined regularizer (say, tf.keras.regularizers.L1) or a custom regularizer to the parameters of the custom layer?

The add_weight method takes a regularizer argument, which you can use to apply regularization to the weight. For example:

self.kernel = self.add_weight(name="kernel",
                              shape=[int(input_shape[-1]), self.num_outputs],
                              # Set the penalty factors explicitly rather than relying on defaults.
                              regularizer=tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01))
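One way to check that the penalty is actually registered is to look at the layer's losses property after the layer has been built. The following is a minimal sketch, not part of the original answer: the class name L1DenseSketch, the "ones" initializer, and the 0.01 factor are assumptions chosen only so the expected penalty is easy to compute by hand.

```python
import tensorflow as tf


class L1DenseSketch(tf.keras.layers.Layer):
    """Hypothetical variant of MyDenseLayer with a fixed L1 kernel penalty."""

    def __init__(self, num_outputs):
        super().__init__()
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]), self.num_outputs),
            initializer="ones",  # assumption: fixed values make the penalty easy to verify
            regularizer=tf.keras.regularizers.L1(0.01),
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


layer = L1DenseSketch(2)
outputs = layer(tf.ones([1, 3]))  # the first call builds the layer

# The kernel is 3x2 and filled with ones, so the L1 penalty is
# 0.01 * sum(|w|) = 0.01 * 6.
penalty = float(layer.losses[0])
print(penalty)
```

When the layer is used inside a Keras model trained with fit, these losses are added to the training objective automatically; in a custom training loop they have to be summed into the loss by hand.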

Alternatively, for the same kind of control that built-in layers offer, you can modify the custom layer's definition and add a kernel_regularizer argument to its __init__ method:

import tensorflow as tf
from tensorflow.keras import regularizers

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs, kernel_regularizer=None):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs
    # regularizers.get accepts None, a string identifier, or a regularizer object.
    self.kernel_regularizer = regularizers.get(kernel_regularizer)

  def build(self, input_shape):
    self.kernel = self.add_weight(name="kernel",
                                  shape=[int(input_shape[-1]), self.num_outputs],
                                  regularizer=self.kernel_regularizer)

  def call(self, inputs):
    # Unchanged from the original layer.
    return tf.matmul(inputs, self.kernel)

Because regularizers.get resolves identifiers, you can even pass a string such as 'l1' or 'l2' as the kernel_regularizer argument when constructing the layer, and it will be resolved to the corresponding regularizer.
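The question also asks about custom regularization. A rough sketch of both cases follows, assuming the layer definition above: a string identifier, and a custom regularizer written as a subclass of regularizers.Regularizer. The MaxAbsPenalty class, its factor parameter, and the "ones" initializer are hypothetical and exist only for illustration.

```python
import tensorflow as tf
from tensorflow.keras import regularizers


class MyDenseLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, kernel_regularizer=None):
        super().__init__()
        self.num_outputs = num_outputs
        self.kernel_regularizer = regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]), self.num_outputs),
            initializer="ones",  # assumption: fixed values for a predictable penalty
            regularizer=self.kernel_regularizer,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


class MaxAbsPenalty(regularizers.Regularizer):
    """Hypothetical custom regularizer: factor * max(|w|)."""

    def __init__(self, factor=0.1):
        self.factor = factor

    def __call__(self, weights):
        return self.factor * tf.reduce_max(tf.abs(weights))

    def get_config(self):
        # Lets Keras serialize the regularizer along with the layer.
        return {"factor": self.factor}


x = tf.ones([1, 3])

# String identifier, resolved by regularizers.get.
l2_layer = MyDenseLayer(2, kernel_regularizer="l2")
l2_layer(x)
l2_penalty = float(l2_layer.losses[0])

# Custom regularizer object, passed through regularizers.get unchanged.
custom_layer = MyDenseLayer(2, kernel_regularizer=MaxAbsPenalty(factor=0.5))
custom_layer(x)
custom_penalty = float(custom_layer.losses[0])  # 0.5 * max(|1|) = 0.5

print(l2_penalty, custom_penalty)
```

A string identifier uses the regularizer's default factor, so pass an instance such as regularizers.L2(1e-4) when a specific strength is needed.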
