
tensorflow: gradients for a custom loss function

I have an LSTM predicting time series values in TensorFlow. The model works with MSE as the loss function. However, I'd like to create a custom loss function where one of the error values is multiplied by two (therefore producing a higher error value).

In my batch of size 10, I want the 3rd value of the first input to be multiplied by 2, but because this is time series, this corresponds to the second value in the second input and the first value in the third input.
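To make the indexing concrete, here is a minimal sketch in plain Python of where the peak lands in each row. It assumes zero-based indices, windows that shift by exactly one time step per row, and the same modulo wrap-around used in the code further down; `peak_positions` is a hypothetical helper name, not something from TensorFlow.

```python
def peak_positions(peak_index, n_windows, width):
    """Return (row, column) of the peak in each sliding window.

    Assumes each window is shifted one step relative to the previous one,
    so the peak moves one column to the left per row (wrapping with modulo).
    """
    return [(row, (peak_index - row) % width) for row in range(n_windows)]


# Third value of the first input is index 2 when counting from zero.
positions = peak_positions(peak_index=2, n_windows=3, width=10)
print(positions)  # [(0, 2), (1, 1), (2, 0)]
```

Note that after the peak reaches column 0, the modulo sends it to the last column of the next row, which is only meaningful if the peaks really repeat with that period.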

The error I get is: ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients

How do I make the gradients?

import tensorflow as tf

def loss_function(y_true, y_pred, peak_value=3, weight=2):
    # peak_value is the column where the peak sits in the first row
    # weight is how much the error should be multiplied by (not yet applied below)

    all_dif = tf.squared_difference(y_true, y_pred)  # should be shape=[10,10]

    peak = [peak_value] * 10

    listy = range(0, 10)
    # column of the peak in each row: it moves one step left per row, wrapping around
    c = [(i - j) % 10 for i, j in zip(peak, listy)]
    for i in range(0, 10):
        indices = [[i, c[i]]]
        values = [1.0]
        shape = [10,10]
        # one-hot [10,10] tensor that adds 1.0 to the error at the peak of row i
        delta = tf.SparseTensor(indices, values, shape)
        all_dif = all_dif + tf.sparse_tensor_to_dense(delta)
    return tf.reduce_sum(all_dif)
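For comparison, here is a rough sketch of one way the weighting could be written so that the peak errors are actually multiplied by `weight` and automatic differentiation still works. It assumes TensorFlow 2.x with eager execution (so `tf.math.squared_difference` and `tf.GradientTape` instead of the 1.x calls above); the name `weighted_squared_error` and the `size` parameter are made up for illustration, and this is not necessarily the fix for the error in the question.

```python
import tensorflow as tf


def weighted_squared_error(y_true, y_pred, peak_value=3, weight=2.0, size=10):
    # Constant weight matrix: 1.0 everywhere, `weight` at the peak of each row.
    rows = [[1.0] * size for _ in range(size)]
    for i in range(size):
        rows[i][(peak_value - i) % size] = weight
    weights = tf.constant(rows, dtype=y_pred.dtype)
    # Element-wise multiply keeps the whole expression differentiable w.r.t. y_pred.
    return tf.reduce_sum(weights * tf.math.squared_difference(y_true, y_pred))


# Toy check that a gradient flows back to the predictions.
y_true = tf.zeros([10, 10])
y_pred = tf.Variable(tf.ones([10, 10]))
with tf.GradientTape() as tape:
    loss = weighted_squared_error(y_true, y_pred)
grads = tape.gradient(loss, y_pred)
```

Because the weights are a constant tensor and only differentiable ops touch `y_pred`, no hand-written gradient is needed in this version.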

I believe the pseudocode would look something like this:

@tf.custom_gradient
def loss_function(y_true, y_pred, peak_value=3, weight=2):
    ## your code
    def grad(dy):
        return dy * partial_derivative
    return loss, grad

Here partial_derivative is the analytically evaluated partial derivative of your loss function with respect to its input. If your loss function is a function of more than one variable, I believe grad has to return one partial derivative per variable.
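A small runnable sketch of that idea, assuming TensorFlow 2.x in eager mode and a weighted sum of squared errors as the loss. The factory `make_weighted_loss` and the 2x2 toy data are hypothetical; the weights are closed over rather than passed as arguments so that the decorated function only takes tensors.

```python
import tensorflow as tf


def make_weighted_loss(weights):
    """Build a loss whose gradient is written out by hand (hypothetical helper)."""
    weights = tf.constant(weights, dtype=tf.float32)

    @tf.custom_gradient
    def loss_function(y_true, y_pred):
        diff = y_pred - y_true
        loss = tf.reduce_sum(weights * tf.square(diff))

        def grad(upstream):
            # d(loss)/d(y_pred) = 2 * w * (y_pred - y_true);
            # d(loss)/d(y_true) is the same expression with the sign flipped.
            d_pred = upstream * 2.0 * weights * diff
            return -d_pred, d_pred  # one gradient per input, in argument order

        return loss, grad

    return loss_function


loss_fn = make_weighted_loss([[1.0, 2.0], [1.0, 1.0]])
y_true = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y_pred = tf.constant([[1.5, 1.0], [3.0, 5.0]])

with tf.GradientTape() as tape:
    tape.watch(y_pred)  # y_pred is a constant here, so it must be watched explicitly
    loss = loss_fn(y_true, y_pred)
custom_grad = tape.gradient(loss, y_pred)
```

For a loss like this one, built only from differentiable ops, automatic differentiation would give the same gradient, so the decorator is mainly useful when you want to override or stabilise the gradient.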

If you need more information, the documentation is good: https://www.tensorflow.org/api_docs/python/tf/custom_gradient

And I've yet to find an example of this functionality embedded in a model that's not a toy.
