简体   繁体   English

Keras,循环输出的成本函数?

[英]Keras, cost function for cyclic outputs?

Right now I am trying to get a neural net to colorize images. 现在,我正在尝试使用神经网络对图像进行着色。 I want to do it in the HSV color space. 我想在HSV颜色空间中执行此操作。 The issue with this is that the hue channel is cyclic. 这样做的问题是色调通道是循环的。 The normalized values for hue are between 0 and 1. Say for instance the model predicts 0.99 but the actual hue is 0.01. 色相的归一化值介于0和1之间。例如,模型预测为0.99,但实际色相为0.01。 With normal mean squared error loss this looks like it is way off. 在具有正常均方误差损失的情况下,这似乎已经过去了。 However the distance is really more like 0.02. 但是,距离实际上更像是0.02。 How can I get a cyclic loss function in keras? 如何获得喀拉拉邦的周期性损失函数?

The true distance from predicted hue A to actual hue B is really the minimum of 3 terms: 从预测的色相A到实际的色相B的真实距离实际上至少是3个项:

  1. (A - B)^2 (distance if you don't wrap around) (A - B)^2 (距离不远的话)
  2. (A - B + 1)^2 (distance if you wrap around to the left) (A - B + 1)^2 (如果绕到左侧,则为距离)
  3. (A - B - 1)^2 (distance if you wrap around to the right) (A - B - 1)^2 (距离,如果环绕在右边)

For instance, in your example the shortest way to get from A = 0.99 to B = 0.01 is to wrap around to the right, and the distance is (A - B - 1)^2 = (0.99 - 0.01 - 1)^2 = (-0.02)^2 = 0.02^2 . 例如,在您的示例中,从A = 0.99B = 0.01的最短方法是环绕在右边,距离为(A - B - 1)^2 = (0.99 - 0.01 - 1)^2 = (-0.02)^2 = 0.02^2

Now that we have the math figured out, how do we implement it? 现在我们已经弄清楚了数学,我们如何实现它? Keras's implementation of mean squared error is: Keras对均方误差的实现是:

from keras import backend as K

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

Here's the tweak to make it cyclic: 这是使其变为周期性的调整:

def cyclic_mean_squared_error(y_true, y_pred):
    return K.mean(K.minimum(K.square(y_pred - y_true), 
                            K.minimum(K.square(y_pred - y_true + 1), 
                                      K.square(y_pred - y_true - 1)), axis=-1)

To use this loss function, specify loss=cyclic_mean_squared_error when compiling the model. 要使用此损失函数, loss=cyclic_mean_squared_error在编译模型时指定loss=cyclic_mean_squared_error

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM