
ValueError: Tensor("BN_1/moments/Squeeze:0", shape=(32, 256, 32), dtype=float32) must be from the same graph as Tensor

I'm trying to get started with TensorFlow in Python, building a simple CNN with batch normalization. But when I create a new graph to run, an exception is raised from the batch normalization code.

My key code is as follows:

**# exception here**
def batch_norm(x, beta, gamma, phase_train, scope='bn', decay=0.9, eps=1e-5):
    with tf.variable_scope(scope):
        batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(phase_train, mean_var_with_update, lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
    return normed
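
For reference, here is roughly how this helper would be wired up (a minimal sketch; the placeholder shape and the names x_in, beta, gamma, bn_demo are illustrative, not taken from the original code):

import tensorflow as tf

x_in = tf.placeholder(tf.float32, shape=[None, 256, 32])   # stand-in for a layer output
beta = tf.Variable(tf.zeros([32]), name='beta')            # learned shift
gamma = tf.Variable(tf.ones([32]), name='gamma')           # learned scale
phase_train = tf.placeholder(tf.bool, name='phase_train')  # fed True/False per sess.run
normed = batch_norm(x_in, beta, gamma, phase_train, scope='bn_demo')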

training code:

# start training
output = conv2d_net()
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.002).minimize(loss)

predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
max_idx_p = tf.argmax(predict, 2)
max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
correct_pred = tf.equal(max_idx_p, max_idx_l)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 0
    while True:
        batch_x, batch_y = get_next_batch(64)
        _, loss_ = sess.run([optimizer, loss],
                            feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75, train_phase: True})
        print(step, loss_)

        if step % 10 == 0 and step != 0:
            batch_x_test, batch_y_test = get_next_batch(100)
            acc = sess.run(accuracy,
                           feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1., train_phase: False})
            print("step %s,accuracy:%s" % (step, acc))
            if acc > 0.05:
                # stop training and save parameters in layer
                result_weights['wc1'] = weights['wc1'].eval(sess)
                ...
                break
        step += 1

Create new graph for exporting:

import os
import shutil

EXPORT_DIR = './model'
if os.path.exists(EXPORT_DIR):
    shutil.rmtree(EXPORT_DIR)

g = tf.Graph()
with g.as_default():
    x_2 = tf.placeholder(tf.float32, shape=[None, IMAGE_HEIGHT * IMAGE_WIDTH], name="input")
    x_image = tf.reshape(x_2, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    # fill trained parameters and create new cnn layers
    WC1 = tf.constant(result_weights['wc1'], name="WC1")
    ...
    **# crash here!!!**
    CONV1 = conv2d(WC1, BC1, x_image, tf.constant(0.0, shape=[32]),
                   tf.random_normal(shape=[32], mean=1.0, stddev=0.02), scope='BN_1')

    OUTPUT = tf.add(tf.matmul(FULL1, W_OUT), B_OUT)
    OUTPUT = tf.nn.sigmoid(OUTPUT, name="output")

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    graph_def = g.as_graph_def()
    tf.train.write_graph(graph_def, EXPORT_DIR, 'phone_model_graph.pb', as_text=True)
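
For completeness, a graph exported this way could later be re-imported along these lines (a sketch, assuming the "input"/"output" node names used above):

import tensorflow as tf
from google.protobuf import text_format

# parse the text-format GraphDef written with as_text=True
with open('./model/phone_model_graph.pb') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())

g2 = tf.Graph()
with g2.as_default():
    tf.import_graph_def(graph_def, name='')
    x_in = g2.get_tensor_by_name('input:0')
    y_out = g2.get_tensor_by_name('output:0')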

I create a new graph at the end. The exception suggests that a parameter from the old training graph is being used. How can this be explained?

Thank you very much!


I call batch_norm from the function conv2d. As far as I can tell, no tensor from the old graph is passed into the new one.

def conv2d(w, b, x, tf_constant, tf_random_normal, scope, keep_p=1., phase=tf.constant(False)):
    out = tf.nn.bias_add(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME'), b)
    out = batch_norm(out, tf_constant, tf_random_normal, phase, scope=scope)
    out = tf.nn.relu(out)
    out = tf.nn.max_pool(out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    out = tf.nn.dropout(out, keep_p)
    return out
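
One detail worth checking here (a hypothesis, not a confirmed diagnosis): the default argument phase=tf.constant(False) is evaluated only once, when the def statement itself runs, so that constant tensor belongs to whatever graph was the default at that moment, which is likely the old training graph. A sketch of a safer signature:

def conv2d(w, b, x, tf_constant, tf_random_normal, scope, keep_p=1., phase=None):
    if phase is None:
        phase = tf.constant(False)  # created in the currently active graph instead
    out = tf.nn.bias_add(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME'), b)
    out = batch_norm(out, tf_constant, tf_random_normal, phase, scope=scope)
    out = tf.nn.relu(out)
    out = tf.nn.max_pool(out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    out = tf.nn.dropout(out, keep_p)
    return out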

"I create a new graph at last."

That's the key statement here: once you create a new graph, you can't use any tensor from the old graph in it. See a detailed explanation in this question. According to the stack trace, at least one of the tensors passed to batch_norm is defined before g.as_default(), and that's why TensorFlow crashes. From your code snippets it's unclear how exactly batch_norm is called, so I can't say which tensor it is.
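
The rule is easy to reproduce in isolation (a standalone sketch):

import tensorflow as tf

a = tf.constant(1.0)      # lives in the initial default graph

g = tf.Graph()
with g.as_default():
    b = tf.constant(2.0)  # lives in g
    c = a + b             # ValueError: ... must be from the same graph as ...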

You can check this hypothesis by printing x.graph and g and checking whether the two values differ. To avoid this problem, you can either do all the work inside a single graph (the recommended way) or define the two graphs in different Python scopes, which makes it impossible to accidentally reuse the same Python variable in both graphs.
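
Concretely, the hypothesis check and the scoped-definition workaround could look like this (a sketch reusing names from the export snippet above):

# 1) every tensor used inside `with g.as_default():` should report g as
#    its graph; a False here pinpoints the leaked tensor
for t in (x_image, WC1):
    print(t.name, t.graph is g)

# 2) build each graph in its own function so a python variable holding an
#    old-graph tensor cannot be reused by accident
def build_training_graph():
    graph = tf.Graph()
    with graph.as_default():
        ...  # define the training ops here
    return graph

def build_export_graph(trained_weights):
    graph = tf.Graph()
    with graph.as_default():
        ...  # rebuild the layers from trained_weights here
    return graph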
