
L2 regularization in tensorflow with high level API

I know there are some similar questions out there regarding l2 regularization with the layer API from tensorflow, but it is still not quite clear to me.

So first I set the kernel_regularizer in my conv2d layers repeatedly like this:

regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)
# inputs, filters and kernel_size are placeholders for the actual layer arguments
conv = tf.layers.conv2d(inputs, filters=64, kernel_size=3, kernel_regularizer=regularizer)

Then I can collect all the regularization losses with the following:

regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

And last but not least I have to incorporate the regularization term into the final loss. However, here I am not quite sure what to do: which one of the following is correct?

1) loss = loss + factor * tf.reduce_sum(regularization_losses)

2) loss = loss + tf.contrib.layers.apply_regularization(regularizer, weights_list=regularization_losses)

Or are both of them wrong? The second option seems weird to me, since I have to pass the regularizer as a parameter once again, even though each layer already has a regularizer as an argument.

EDIT

# weighted MSE as the base loss (weights=1000 scales the error term)
loss_1 = tf.losses.mean_squared_error(labels=y, predictions=logits, weights=1000)

# list of per-layer regularization loss tensors
regularization_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

# final loss = base loss + sum of all regularization losses
loss = tf.add_n([loss_1] + regularization_loss, name='loss')

The first method is correct. One more way of doing this is via the tf.add_n function:

reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
loss = tf.add_n([base_loss] + reg_losses, name="loss")
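
The resulting loss can then be fed to an optimizer as usual; a minimal sketch (the optimizer choice and learning rate here are placeholders):

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)  # placeholder optimizer
train_op = optimizer.minimize(loss)  # minimizes base loss + regularization together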

The second method also works, but you'll have to define a single regularizer. So it works in your case, but may be inconvenient if you use different regularizers in different layers.
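
For illustration, here is a minimal sketch of that second route, assuming all layers share the single regularizer defined in the question; note that apply_regularization expects the weight variables themselves, not the already-collected loss tensors:

# gather the kernel variables created by tf.layers (named ".../kernel:0")
weights = [v for v in tf.trainable_variables() if v.name.endswith('kernel:0')]
# apply the one shared regularizer to all of them at once
reg_term = tf.contrib.layers.apply_regularization(regularizer, weights_list=weights)
loss = base_loss + reg_term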
