
How does TensorFlow calculate the gradients for the tf.train.GradientDescentOptimizer?

I am trying to understand how TensorFlow computes the gradients for the tf.train.GradientDescentOptimizer.

If I understand section 4.1 of the TensorFlow whitepaper correctly, it computes the gradients via backpropagation by adding nodes to the TensorFlow graph that compute the derivative of a node in the original graph.

When TensorFlow needs to compute the gradient of a tensor C with respect to some tensor I on which C depends, it first finds the path in the computation graph from I to C. Then it backtracks from C to I, and for each operation on the backward path it adds a node to the TensorFlow graph, composing the partial gradients along the backwards path using the chain rule. The newly added node computes the “gradient function” for the corresponding operation in the forward path. A gradient function may be registered by any operation. This function takes as input not only the partial gradients computed already along the backward path, but also, optionally, the inputs and outputs of the forward operation. [Section 4.1 TensorFlow whitepaper]

Question 1: Is there a second node implementation for each TensorFlow node that represents the derivative of the original node?

Question 2: Is there a way to visualize which derivative nodes get added to the graph (or are there any logs)?

Each op type gets a corresponding function that computes its backprop values (registered in Python using something like @ops.RegisterGradient("Sum")).
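To make the registration idea concrete, here is a minimal pure-Python sketch of the mechanism the whitepaper describes: a registry maps each op type to a gradient function, and a backward pass adds new nodes to the graph by calling those functions. This is a toy model, not TensorFlow's implementation; every name in it (GRADIENT_REGISTRY, register_gradient, Node, gradients, evaluate) is hypothetical and chosen only for illustration.

```python
# Toy model of gradient registration and backward graph construction.
# All names are hypothetical; this is NOT TensorFlow's API or implementation.

GRADIENT_REGISTRY = {}  # op type -> gradient function


def register_gradient(op_type):
    """Decorator that records the gradient function for one op type."""
    def decorator(fn):
        GRADIENT_REGISTRY[op_type] = fn
        return fn
    return decorator


class Node:
    """A graph node: an op type, its input nodes, and a value for constants."""
    def __init__(self, op_type, inputs=(), value=None):
        self.op_type = op_type
        self.inputs = list(inputs)
        self.value = value


def const(value):
    return Node("Const", value=value)


def add(a, b):
    return Node("Add", [a, b])


def mul(a, b):
    return Node("Mul", [a, b])


@register_gradient("Add")
def _add_grad(node, upstream):
    # d(a + b)/da = 1 and d(a + b)/db = 1, so the upstream gradient passes through.
    return [upstream, upstream]


@register_gradient("Mul")
def _mul_grad(node, upstream):
    # d(a * b)/da = b and d(a * b)/db = a; each result is a NEW graph node.
    a, b = node.inputs
    return [mul(upstream, b), mul(upstream, a)]


def gradients(target, source):
    """Add nodes computing d(target)/d(source); return that node (or None)."""
    order, seen = [], set()

    def visit(node):  # post-order DFS: inputs are listed before their consumers
        if id(node) in seen:
            return
        seen.add(id(node))
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    visit(target)
    grads = {id(target): const(1.0)}
    for node in reversed(order):  # backtrack from target toward the inputs
        upstream = grads.get(id(node))
        if upstream is None or not node.inputs:
            continue
        input_grads = GRADIENT_REGISTRY[node.op_type](node, upstream)
        for inp, grad in zip(node.inputs, input_grads):
            # Sum partial gradients when a node feeds several consumers.
            grads[id(inp)] = add(grads[id(inp)], grad) if id(inp) in grads else grad
    return grads.get(id(source))


def evaluate(node):
    """Recursively compute the numeric value of a node."""
    if node.op_type == "Const":
        return node.value
    a, b = (evaluate(inp) for inp in node.inputs)
    return a + b if node.op_type == "Add" else a * b


x, y = const(3.0), const(4.0)
z = add(mul(x, y), x)      # z = x*y + x
dz_dx = gradients(z, x)    # a new subgraph representing y + 1
dz_dy = gradients(z, y)    # a new subgraph representing x
print(evaluate(dz_dx), evaluate(dz_dy))
```

Note that gradients() returns graph nodes rather than numbers: the derivative is itself a piece of graph built from ordinary ops, which is the sense in which there is no separate "derivative node" type, only a gradient function per op type that emits more nodes.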

You can visualize the graph using method here

However, note that since the automatic differentiation code is meant to work for a range of conditions, the graph it creates is quite complicated and not very useful to look at. It's not uncommon to have 10 op nodes for a simple gradient calculation that could be implemented with 1-2 nodes.
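If you only want to see which ops the gradient computation adds, without a visualizer, one option is to compare the graph's op list before and after calling tf.gradients. The sketch below assumes a TensorFlow install that provides tf.compat.v1 (graph mode); the tensors x and y are made-up examples, and the exact ops listed will vary by TensorFlow version.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes a TensorFlow version that ships tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    # Hypothetical forward graph: y = sum(x ** 2)
    x = tf1.placeholder(tf.float32, shape=[3], name="x")
    y = tf.reduce_sum(tf.square(x), name="y")

    # Snapshot the op names, build the gradient, then diff the op lists.
    ops_before = {op.name for op in graph.get_operations()}
    (grad_x,) = tf1.gradients(y, [x])
    added_ops = [op for op in graph.get_operations() if op.name not in ops_before]

# Print the type and name of every op that the gradient construction added.
for op in added_ops:
    print(op.type, op.name)
```

Even for this two-op forward graph the list is typically noticeably longer than the one or two ops a hand-written gradient would need, which illustrates the point above.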

