How to save the value of a Tensor in TensorFlow?

I want to save the last output vector of a 3D CNN. A small piece of the code is shown below; the tensor I want to save is x.

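# Excerpt from DLTK's resnet.py: x, mode and relu_op are defined earlier in that function
with tf.variable_scope('pool'):
    x = tf.layers.batch_normalization(
        x, training=mode == tf.estimator.ModeKeys.TRAIN)
    x = relu_op(x)

    # Global average pooling over the spatial axes (every axis except batch and channels)
    axis = tuple(range(len(x.get_shape().as_list())))[1:-1]
    x = tf.reduce_mean(x, axis=axis, name='global_avg_pool')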

While debugging, the debugger shows x as:

{Tensor} Tensor ("pool/global_avg_pool:()", shape=(?, 256), dtype=float32)

I wrote the following code to save this tensor:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    np.save('x.npy', sess.run(x), allow_pickle=False)

But I got the following error:

[[node pool/batch_normalization/gamma/read (defined at C:\Users\Nastaran\AppData\Roaming\Python\Python36\site-packages\dltk\networks\regression_classification\resnet.py:115)  = Identity[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](pool/batch_normalization/gamma)]]

How can I save the float values of this tensor to a vector?

Here is working code that saves a single tensor, such as the output of your CNN:

import tensorflow as tf
import numpy as np

x = tf.linspace(start=1., stop=2., num=3)  # A tensor

with tf.Session() as sess:
    x_value = sess.run(x)  # Evaluate the tensor into a NumPy array
    print(x_value)  # [1.  1.5 2. ]

    np.save("x.npy", x_value, allow_pickle=False)

# Check it worked
print(np.load("x.npy"))  # [1.  1.5 2. ]

sess.run can also fetch nested data structures, such as lists or dictionaries of tensors; in that case, pickle or another serialization library may be a more general way to save the results than np.save.
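A minimal sketch of the nested case, assuming the TF1-style API through `tensorflow.compat.v1` (so it also runs under TF 2.x) and a hypothetical output file `results.pkl`. Fetching a dictionary of tensors returns a dictionary of NumPy arrays with the same keys, which pickle can store as a single object.

```python
import pickle

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

a = tf.linspace(1., 2., 3)         # a 1-D float tensor
b = tf.constant([[1, 2], [3, 4]])  # a 2-D int tensor

with tf.Session() as sess:
    # Fetching a dict returns a dict with the same keys, holding NumPy arrays
    results = sess.run({"a": a, "b": b})

# pickle can serialize arbitrary nested structures of arrays
with open("results.pkl", "wb") as f:
    pickle.dump(results, f)

with open("results.pkl", "rb") as f:
    restored = pickle.load(f)
print(restored["a"])  # [1.  1.5 2. ]
```

If the fetched structure is a flat dictionary of arrays, `np.savez("results.npz", **results)` is a pickle-free alternative.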

It's difficult to say what the exact issue is from your (good) minimal example. However, the problem appears to be in running your graph to get x, not in saving it. If so, sess.run(x) on its own should raise the same error.
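One way to check this, sketched under assumptions: the snippet below (again using `tensorflow.compat.v1`) builds a tiny graph in which a hypothetical, never-initialized variable `gamma` stands in for the batch-norm parameter, and calls `sess.run` with no saving involved. The error is caught through the generic `tf.errors.OpError` base class rather than a specific subclass.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for a batch-norm parameter such as gamma
    gamma = tf.get_variable("gamma", shape=[3],
                            initializer=tf.ones_initializer())
    x = gamma * tf.linspace(1., 2., 3)

run_failed = False
with tf.Session(graph=graph) as sess:
    try:
        sess.run(x)  # no np.save here: only the graph is evaluated
    except tf.errors.OpError as err:  # base class of TensorFlow runtime errors
        run_failed = True
        print("sess.run(x) failed:", type(err).__name__)
```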

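I get a similar error when I run the output of a batch normalization layer without initializing its variables. The failing node in your error, pool/batch_normalization/gamma/read, reads the batch normalization's gamma variable, which fits this explanation. Check that you call sess.run(tf.global_variables_initializer()), or an equivalent, at the start of your session.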
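A minimal sketch of the fix, under the same assumptions (`tensorflow.compat.v1`, a hypothetical `gamma` variable standing in for the batch-norm parameter): run the initializer first, then evaluate `x` and save it with `np.save` as in the working example.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for a batch-norm parameter such as gamma
    gamma = tf.get_variable("gamma", shape=[3],
                            initializer=tf.ones_initializer())
    x = gamma * tf.linspace(1., 2., 3)
    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init_op)      # initialize every variable first
    x_value = sess.run(x)  # evaluating x now succeeds
    np.save("x.npy", x_value, allow_pickle=False)

loaded = np.load("x.npy")
print(loaded)
```

Keep in mind that the initializer gives variables their initial values, not trained ones. If you need the output of an already trained network, restore its weights from a checkpoint instead, for example with `tf.train.Saver().restore(sess, checkpoint_path)`, where `checkpoint_path` is a placeholder for your own checkpoint.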

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM