
How to use a tensorflow tensor value in a formula?

I have a quick question. I am developing a model in TensorFlow and need to use the iteration number in a formula during the graph construction phase. I know how to use global_step, but I am not using an existing optimizer. I am computing my own gradients with

grad_W, grad_b = tf.gradients(xs=[W, b], ys=cost)
grad_W = grad_W +rnd.normal(0,1.0/(1+epoch)**0.55)

and then using

new_W = W.assign(W - learning_rate * (grad_W))
new_b = b.assign(b - learning_rate * (grad_b))

and would like to use the epoch value in the formula before updating my weights. What is the best way to do this? I have a sess.run() call and would like to pass the epoch number to the model, but I cannot use a tensor directly in the formula. From my run call

_, _, cost_ = sess.run([new_W, new_b, cost],
      feed_dict={X_: X_train_tr, Y: labels_, learning_rate: learning_r})

I would like to pass the epoch number. How do you usually do it?

Thanks in advance, Umberto

EDIT:

Thanks for the hints. This seems to work:

grad_W = grad_W + tf.random_normal(grad_W.shape, 
      0.0,1.0/tf.pow(0.01+tf.cast(epochv, tf.float32),0.55))

but I still have to check whether that is what I need and whether it works as intended. Ideas and feedback would be great!
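One thing worth checking in the edited snippet: it uses 0.01 + epoch, whereas the original formula used 1 + epoch. The two offsets give quite different noise levels in the first epochs. Below is a minimal plain-Python sketch for comparing them; the function name noise_std and its offset and power parameters are made up for illustration and are not part of the original code.

```python
def noise_std(epoch, offset=1.0, power=0.55):
    """Noise standard deviation 1 / (offset + epoch) ** power."""
    return 1.0 / (offset + epoch) ** power


# Compare the offset from the original formula (1) with the one in the edit (0.01).
for epoch in (0, 1, 10, 100):
    print(epoch, noise_std(epoch, offset=1.0), noise_std(epoch, offset=0.01))
```

At epoch 0 the offset of 1 gives a standard deviation of exactly 1.0, while the offset of 0.01 gives 10 ** 1.1, roughly 12.6, so the first updates are much noisier with the edited formula. From a few epochs onward the two are nearly the same.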

You can define epoch as a non-trainable tf.Variable in your graph and increment it at the end of each epoch. To do so, define an increment operation with tf.assign_add and run it at the end of each epoch.

Since the standard deviation of the noise is then a tensor, you will also need to use tf.random_normal instead of rnd.normal.

Example:

epoch = tf.Variable(0, trainable=False) # 0 is initial value
# increment by 1 when the next op is run
epoch_incr_op = tf.assign_add(epoch, 1, name='incr_epoch')

# Define any operations that depend on 'epoch'
# Note we need to cast the integer 'epoch' to float to use in tf.pow
grad_W = grad_W + tf.random_normal(grad_W.shape, 0.0,
                                  1.0/tf.pow(1+tf.cast(epoch, tf.float32), 0.55))

# Training loop: run the update ops for every step of the current epoch
while running_epoch:
    _, _, cost_ = sess.run([new_W, new_b, cost],
                           feed_dict={X_: X_train_tr, Y: labels_, learning_rate: learning_r})

# At the end of the epoch, increment the epoch counter
sess.run(epoch_incr_op)
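To see the control flow without the TensorFlow graph, here is a rough plain-Python sketch of the same idea: a counter that is incremented once per epoch and that sets the noise level for every update inside that epoch. Everything in it is an assumption for illustration only (the function train, the toy cost (w - 3)^2, and the default values); it is not the original model.

```python
import random


def train(num_epochs=50, batches_per_epoch=10, learning_rate=0.1, seed=0):
    """Minimise the toy cost (w - 3)^2 with gradient noise that shrinks per epoch."""
    rng = random.Random(seed)
    w = 0.0
    epoch = 0  # plays the role of the non-trainable counter variable
    while epoch < num_epochs:
        # Same schedule as in the post: std = 1 / (1 + epoch) ** 0.55
        noise_std = 1.0 / (1 + epoch) ** 0.55
        for _ in range(batches_per_epoch):
            grad = 2.0 * (w - 3.0)             # exact gradient of the toy cost
            grad += rng.gauss(0.0, noise_std)  # annealed Gaussian noise
            w -= learning_rate * grad          # manual weight update
        epoch += 1  # counterpart of sess.run(epoch_incr_op)
    return w, epoch


w_final, epochs_done = train()
print(w_final, epochs_done)
```

The counter is only updated after all steps of an epoch have run, so every step within one epoch sees the same noise level, which mirrors running the increment op once after the inner training loop.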


 