簡體   English   中英

如何根據現有Tensorflow操作組成來創建新的Tensorflow操作

[英]How to create new Tensorflow op from composition of existing Tensorflow ops

我知道如何使用tf.py_func創建在CPU上運行的新自定義操作 我也從TF指南知道,您可以在C ++中創建一個新的op及其漸變

我正在尋找的不是以上所有內容。 我想為TF ops的組合定義自定義漸變函數。 tf.register_gradients可以與gradient_override_map一起使用,以為現有操作定義自定義漸變,但是首先如何將TF操作的組成注冊為新操作?

這里有人提出類似的問題,但沒有答案。

tfe.custom_gradient是您要使用的裝飾器

我在倉庫中提供了三種在Tensorflow中定義自定義漸變的不同方法。

custom_gradient_with_py_func

在這種方法中,我們使用tf.py_func定義了tf op,並為其分配了自定義漸變函數。

with g.gradient_override_map({"PyFunc": rnd_name}):
    return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

custom_gradient_with_python:

在這種方法中,我們使用一種變通辦法為Tensorflow ops的組合定義自定義漸變。 我們忽略了身份運算的梯度。

def python_func(x_in, name=None):
    with ops.name_scope(name):
        backward_func = tf.identity(x_in) # We'll later override the gradient of identity to deflect our desired gradient function.
        forward_func = tf.subtract(2 * tf.exp(x_in), x_in) 
        return backward_func + tf.stop_gradient(forward_func - backward_func) 

def my_op(func, inp, grad, name=None, victim_op='Identity'):
    # Need to generate a unique name to avoid duplicates.
    rnd_name = 'my_gradient' + str(np.random.randint(0, 1E+8))
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    with g.gradient_override_map({victim_op: rnd_name}):
        return func(inp, name=name)

custom_gradient_with_eager:

此方法使用Tensorflow 1.5或更高版本中可用的tensorflow.contrib.eager來定義用於Tensorflow Ops組成的自定義漸變。

@tfe.custom_gradient
def python_func(x_in):
    def grad_func(grad):
        return grad * ((2 * tf.exp(x_in)) - 1)

    forward_func = tf.subtract(2 * tf.exp(x_in), x_in)
    return forward_func, grad_func

我不確定您如何設法解決問題,但是上述解決方案中的名稱“ op_name”和“ some_name”不會顯示在圖表上。 因此,您將無法使用gradient_override_map({“ op_name”:“ SynthGrad”})。

一種可能的解決方案:如果在前向通行中有一個自定義的tensorflow op x = f(a,b),但希望在后向通行中表現為g(a,b),則可以執行以下操作:

t = g(a,b)out = t + tf.stop_gradient(f(a,b)-t)

但是,您需要在C ++中將g(a,b)定義為具有名稱的虛擬/身份運算符。 稍后,您可以使用gradient_override_map。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM