簡體   English   中英

如何在張量流中交替訓練op?

[英]How to alternate train op's in tensorflow?

我正在實施交替培訓計划。 該圖包含兩個訓練操作。 培訓應在這些之間交替進行。

這是相關研究像這樣

以下是一個小例子。 但是,似乎每個步驟都會更新兩個操作。 我如何明確地在這些之間交替?

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Import data
mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10]), name='biases')
y = tf.matmul(x, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
global_step = tf.Variable(0, trainable=False)

tvars1 = [b]
train_step1 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step)
tvars2 = [W]
train_step2 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step)
train_step = tf.cond(tf.equal(tf.mod(global_step,2), 0), true_fn= lambda:train_step1, false_fn=lambda : train_step2)


sess = tf.InteractiveSession()
tf.global_variables_initializer().run()


# Train
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        print(sess.run([cross_entropy, global_step], feed_dict={x: mnist.test.images,
                                         y_: mnist.test.labels}))

這導致

[2.0890141, 2]
[0.38277805, 202]
[0.33943111, 402]
[0.32314575, 602]
[0.3113254, 802]
[0.3006627, 1002]
[0.2965056, 1202]
[0.29858461, 1402]
[0.29135355, 1602]
[0.29006076, 1802]      

全局步驟迭代到1802,因此,每次調用train_step ,將同時執行兩個火車操作。 (例如,當永遠為假的條件為tf.equal(global_step,-1)時,也會發生這種情況。)

我的問題是如何在執行train_step1train_step2之間train_step2

我認為最簡單的方法就是

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  if i % 2 == 0:
    sess.run(train_step1, feed_dict={x: batch_xs, y_: batch_ys})
  else:
    sess.run(train_step2, feed_dict={x: batch_xs, y_: batch_ys})

但如果有必要通過tensorflow條件流進行切換,請按以下方式進行操作:

optimizer = tf.train.GradientDescentOptimizer(0.5)
train_step = tf.cond(tf.equal(tf.mod(global_step, 2), 0),
                     true_fn=lambda: optimizer.apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step),
                     false_fn=lambda: optimizer.apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM