
Get shape of tensor as array in TensorFlow

I have a saved model, and I want the weights that are applied in its final layer. I have loaded the graph and know where the tensor is, but I can't get the tensor's contents as an array. I know the array has shape 2048x6. How do I get the actual values, like so:
[[1,2,3],[1,2,3]...]? Thanks.

Here is my code:

import tensorflow as tf

# Import the graph structure saved in the .meta file
saver = tf.train.import_meta_graph('_retrain_checkpoint.meta')
graph = tf.get_default_graph()

# Look up the final-layer weights tensor by name
tensor = graph.get_tensor_by_name("final_retrain_ops/weights/final_weights:0")

print(tensor)
print(tensor.get_shape().as_list())

>>>Tensor("final_retrain_ops/weights/final_weights:0", shape=(2048, 6), dtype=float32_ref)
>>>[2048, 6]

To print the values of the weight tensor, you can do the following:

with tf.Session() as sess:
    print(sess.run(tensor))

sess.run() evaluates the tensor(s) passed as its argument and returns their values as NumPy arrays, so here it simply prints the weights.
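Because sess.run returns a NumPy ndarray, the nested-list form from the question ([[1,2,3],[1,2,3]...]) is one .tolist() call away. The sketch below uses a stand-in array in place of the real weights (an assumption, so it runs without the checkpoint):

```python
import numpy as np

# Hypothetical stand-in for the value returned by sess.run(tensor):
# a float32 array with the same (2048, 6) shape as the final-layer weights.
weights = np.arange(2048 * 6, dtype=np.float32).reshape(2048, 6)

print(weights.shape)  # (2048, 6)

# .tolist() converts the array into nested Python lists, one inner list per row.
nested = weights.tolist()
print(len(nested), len(nested[0]))  # 2048 6
print(nested[0])                    # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```

In the real script you would write weights = sess.run(tensor) inside the session and then call weights.tolist(). Printing the ndarray directly abbreviates an array this large (NumPy's default print threshold is 1000 elements), whereas the list form prints every row.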

There is an issue, however: your code only loads the structure of the graph (tf.train.import_meta_graph('_retrain_checkpoint.meta')), not the trained values. Running the session as above therefore fails with a FailedPreconditionError saying that you are attempting to use an uninitialized value.

You need to have something like:

saver.restore(sess, tf.train.latest_checkpoint('./'))

to load the values, right after sess has been defined. Of course, you need to point to the correct checkpoint directory instead of './'.

So something like this:

with tf.Session() as sess:
    # Load the trained variable values into this session before evaluating
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print(sess.run(tensor))
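Putting the pieces together, here is a self-contained sketch that runs without the original _retrain_checkpoint files. It assumes TensorFlow 1.15 or 2.x and uses the tf.compat.v1 graph/session API; the temporary directory, the "model" checkpoint prefix and the all-ones 2048x6 stand-in variable are hypothetical, created only so there is something to restore.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Graph mode and ref variables (needed on TF 2.x), matching the
# float32_ref tensor shown in the question.
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()  # hypothetical checkpoint directory

# 1) Build and save a tiny stand-in model with the same variable name.
with tf.Graph().as_default():
    with tf.variable_scope("final_retrain_ops"):
        with tf.variable_scope("weights"):
            tf.get_variable("final_weights", shape=(2048, 6),
                            initializer=tf.ones_initializer())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes model.meta plus the variable data and a "checkpoint" index file
        saver.save(sess, os.path.join(ckpt_dir, "model"))

# 2) The approach from the answer: import the structure, then restore values.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(os.path.join(ckpt_dir, "model.meta"))
    tensor = graph.get_tensor_by_name("final_retrain_ops/weights/final_weights:0")
    with tf.Session() as sess:
        # Without this restore, sess.run(tensor) fails with an
        # uninitialized-value error.
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        weights = sess.run(tensor)

print(weights.shape)        # (2048, 6)
print(weights.tolist()[0])  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

With a real retraining checkpoint, only step 2 is needed, with ckpt_dir and the .meta path pointing at your own files.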
