How are the past gradients accumulated in Momentum/Adam when we preprocess the gradients before tf.train.Optimizer.apply_gradients()?

I want to preprocess the gradients before apply_gradients, and I want the past gradients to be accumulated from the processed gradients when tf.train.MomentumOptimizer or tf.train.AdamOptimizer is used. I know we can preprocess the gradients between compute_gradients and apply_gradients, as shown in this example:

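```python
import tensorflow as tf  # TF 1.x API; in TF 2.x use tf.compat.v1.train

# Create an optimizer.
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)

# Compute the gradients for a list of variables.
# `loss` and `var_list` are placeholders for your own loss tensor and variables.
grads_and_vars = opt.compute_gradients(loss, var_list)

# grads_and_vars is a list of tuples (gradient, variable). Do whatever you
# need to the 'gradient' part, for example cap them, etc.
# `MyCapper` is a placeholder for your own transformation; variables that do
# not affect the loss have a None gradient, which is passed through untouched.
capped_grads_and_vars = [(MyCapper(g) if g is not None else None, v)
                         for g, v in grads_and_vars]

# Ask the optimizer to apply the capped gradients.
train_op = opt.apply_gradients(capped_grads_and_vars)
```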

My question is: in the above case, are the past gradients accumulated from the capped gradients or from the uncapped ones?

Thanks!

All of the state that optimizers keep is updated in apply_gradients. The call chain is a bit complicated (best followed in optimizer.py), but in short, apply_gradients eventually calls either _apply_sparse or _apply_dense for each variable (ignoring resource variables, which go through _resource_apply_sparse and _resource_apply_dense instead).
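The dispatch described above can be modelled in a few lines of plain Python. This is a hypothetical toy, not TensorFlow code: `ToyOptimizer` and `Sparse` are invented stand-ins (`Sparse` plays the role of `tf.IndexedSlices` on a 1-D variable), plain SGD is used as the update rule, and the real method additionally creates slot variables and builds graph ops. The point it illustrates is that the update methods only ever see the gradient object you handed to apply_gradients.

```python
from typing import List


class Sparse:
    """Hypothetical stand-in for tf.IndexedSlices on a 1-D variable."""

    def __init__(self, indices: List[int], values: List[float]):
        self.indices = indices  # which entries of the variable get a gradient
        self.values = values    # one gradient value per index


class ToyOptimizer:
    """Hypothetical toy model of the dense/sparse dispatch in apply_gradients.

    Plain SGD is the update rule; the point is only that each _apply_*
    method receives exactly the gradient passed to apply_gradients.
    """

    def __init__(self, learning_rate: float):
        self.lr = learning_rate
        self.calls: List[str] = []  # records which update path was taken

    def apply_gradients(self, grads_and_vars):
        for grad, var in grads_and_vars:
            if grad is None:
                continue  # variables without a gradient are skipped
            if isinstance(grad, Sparse):
                self.calls.append("sparse")
                self._apply_sparse(grad, var)
            else:
                self.calls.append("dense")
                self._apply_dense(grad, var)

    def _apply_dense(self, grad: List[float], var: List[float]):
        for i, g in enumerate(grad):
            var[i] -= self.lr * g

    def _apply_sparse(self, grad: Sparse, var: List[float]):
        for i, g in zip(grad.indices, grad.values):
            var[i] -= self.lr * g


# Usage: one dense gradient, one sparse gradient, one missing gradient.
w = [1.0, 1.0]
emb = [1.0, 1.0, 1.0]
opt = ToyOptimizer(learning_rate=0.5)
opt.apply_gradients([([2.0, 0.0], w), (Sparse([1], [2.0]), emb), (None, [0.0])])
# w is now [0.0, 1.0], emb is [1.0, 0.0, 1.0], opt.calls == ["dense", "sparse"]
```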

Taking Adam as an example, its _apply_sparse is relatively easy to read because it is built from a sequence of ordinary ops rather than a single fused C++ op. You can see that it updates both moment accumulators and then the variable itself, all from the gradient it was given.
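As a rough plain-Python sketch of what that sequence of ops computes: the function below is hypothetical and only modelled on the standard Adam equations (m_t = beta1 * m + (1 - beta1) * g, v_t = beta2 * v + (1 - beta2) * g^2, bias-corrected step size), following the "decay every entry, then scatter-add the new contributions, then update the whole variable" pattern. It assumes a 1-D variable and unique indices, and is not a copy of the TensorFlow source.

```python
import math
from typing import List


def adam_sparse_step(var: List[float], m: List[float], v: List[float],
                     indices: List[int], values: List[float], step: int,
                     lr: float = 0.001, beta1: float = 0.9,
                     beta2: float = 0.999, eps: float = 1e-8) -> None:
    """Hypothetical sketch: one Adam step from a sparse gradient, in place.

    Assumes `indices` has no duplicates and `step` starts at 1.
    """
    # Bias-corrected step size.
    lr_t = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    # m_t = beta1 * m + (1 - beta1) * g: decay every entry, then scatter-add.
    for i in range(len(m)):
        m[i] *= beta1
    for i, g in zip(indices, values):
        m[i] += (1 - beta1) * g
    # v_t = beta2 * v + (1 - beta2) * g * g: same pattern.
    for i in range(len(v)):
        v[i] *= beta2
    for i, g in zip(indices, values):
        v[i] += (1 - beta2) * g * g
    # The variable update reads the freshly updated moments for every entry.
    for i in range(len(var)):
        var[i] -= lr_t * m[i] / (math.sqrt(v[i]) + eps)


# Usage: a gradient of 0.5 on entry 0 only, first step.
var, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
adam_sparse_step(var, m, v, indices=[0], values=[0.5], step=1)
# var[0] moves by about lr; var[1] is unchanged because its moments are still 0.
```

Whatever values arrive in `values` are exactly what end up in `m` and `v`, which is the mechanism behind the answer below.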

So, to answer your question: if you cap the gradients before calling apply_gradients, the capped gradients are what get accumulated in Adam's moments (and likewise in other optimizers' accumulators, such as Momentum's).
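A minimal plain-Python check of that conclusion for Momentum, using the non-Nesterov rule accum = accum * momentum + grad that tf.train.MomentumOptimizer documents. The helpers `clip` and `momentum_accumulator` are hypothetical; `clip` stands in for `MyCapper` from the question.

```python
from typing import Iterable, Optional


def clip(g: float, limit: float) -> float:
    """Clip a scalar gradient to [-limit, limit] (stand-in for MyCapper)."""
    return max(-limit, min(limit, g))


def momentum_accumulator(raw_grads: Iterable[float], momentum: float = 0.9,
                         cap: Optional[float] = None) -> float:
    """Return the momentum accumulator after feeding in raw_grads.

    Uses the non-Nesterov rule accum = accum * momentum + grad.
    """
    accum = 0.0
    for g in raw_grads:
        if cap is not None:
            g = clip(g, cap)  # preprocessing done before apply_gradients
        accum = accum * momentum + g  # only the (possibly capped) g is stored
    return accum


print(momentum_accumulator([10.0, 10.0]))           # -> 19.0
print(momentum_accumulator([10.0, 10.0], cap=1.0))  # -> 1.9
```

The uncapped values never reach the accumulator once capping happens first, so the history is built entirely from capped gradients.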

There is one gotcha with sparse gradients (IndexedSlices): they are not aggregated as they pass through the graph, so the same index can appear several times, each time carrying its own piece of the gradient. If you cap the pieces in this disaggregated form, the contributions for a repeated index can still sum to more than your cap once they are combined. This only matters when the gradient is sparse, e.g. when you use gather() or embeddings, but it is worth keeping in mind.
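A small plain-Python illustration of the gotcha, with a 1-D variable and scalar values per index for simplicity. All function names are hypothetical. Capping each (index, value) pair on its own leaves a repeated row above the cap once the pieces are summed; summing duplicates first and then capping does not. In TensorFlow, one option for the "sum first" step would be to combine `tf.unique` with `tf.math.unsorted_segment_sum` (named `tf.unsorted_segment_sum` in older 1.x releases) before capping, though that is a suggestion rather than something the answer prescribes.

```python
from collections import defaultdict
from typing import List, Tuple


def cap(x: float, limit: float) -> float:
    """Clip a scalar to [-limit, limit]."""
    return max(-limit, min(limit, x))


def aggregate(indices: List[int], values: List[float]) -> Tuple[List[int], List[float]]:
    """Sum the values that share an index; return sorted unique indices."""
    totals = defaultdict(float)
    for i, x in zip(indices, values):
        totals[i] += x
    unique = sorted(totals)
    return unique, [totals[i] for i in unique]


def cap_disaggregated(indices: List[int], values: List[float], limit: float):
    """Cap each (index, value) pair separately, like capping IndexedSlices.values as-is."""
    return indices, [cap(x, limit) for x in values]


def cap_aggregated(indices: List[int], values: List[float], limit: float):
    """Sum duplicate indices first, then cap, so each row's total respects the cap."""
    unique, summed = aggregate(indices, values)
    return unique, [cap(x, limit) for x in summed]


indices, values = [0, 0, 2], [0.8, 0.8, 0.3]
# Capping the pieces leaves them unchanged, but row 0 still totals 1.6.
print(aggregate(*cap_disaggregated(indices, values, 1.0)))  # ([0, 2], [1.6, 0.3])
# Aggregating first keeps every row within the cap.
print(cap_aggregated(indices, values, 1.0))                 # ([0, 2], [1.0, 0.3])
```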
