I have a saved model from which I want the weights applied in the final layer. I have loaded the graph and know where the tensor is, but I can only get its shape, not its values as an array. I know the array has the shape 2048x6. How do I get the actual values, like so:
[[1,2,3],[1,2,3]...]? Thanks.
Here is my code:
import tensorflow as tf
# Rebuild the graph structure (not the values) from the .meta file
saver = tf.train.import_meta_graph('_retrain_checkpoint.meta')
graph = tf.get_default_graph()
tensor = graph.get_tensor_by_name("final_retrain_ops/weights/final_weights:0")
print(tensor)
print(tensor.get_shape().as_list())
>>>Tensor("final_retrain_ops/weights/final_weights:0", shape=(2048, 6), dtype=float32_ref)
>>>[2048, 6]
To print the values of the weight tensor, you can do the following:
with tf.Session() as sess:
    print(sess.run(tensor))
sess.run() evaluates the tensor(s) passed as its argument and returns their values (here a NumPy array of shape (2048, 6)), which print then displays.
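Two details matter for getting the [[1,2,3],[1,2,3]...] form. First, NumPy abbreviates arrays with more than 1000 elements when printing (its default print threshold), so a 2048x6 array (12288 values) prints with ... in the middle. Second, calling .tolist() on the returned array gives plain nested Python lists, one inner list per row. The sketch below uses a placeholder array of the same shape instead of a real sess.run(tensor) result, so it runs without a checkpoint.

```python
import numpy as np

# Placeholder for the array returned by sess.run(tensor); the real one also
# has shape (2048, 6) and dtype float32, but different values.
weights = np.arange(2048 * 6, dtype=np.float32).reshape(2048, 6)

# Printing the array directly elides most rows because it has more than
# NumPy's default threshold of 1000 elements.
print(weights)

# tolist() returns nested Python lists: 2048 inner lists of 6 floats each.
nested = weights.tolist()
print(len(nested), len(nested[0]))  # 2048 6
print(nested[0])                    # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```

With a real session the equivalent is print(sess.run(tensor).tolist()); alternatively, np.set_printoptions(threshold=sys.maxsize) makes print show every element of the array.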
There is one issue, however: your code only loads the structure of the graph (tf.train.import_meta_graph('_retrain_checkpoint.meta')), not the trained values. Running the snippet above as-is therefore fails with an error about attempting to use an uninitialized value.
To load the values, you need something like:
saver.restore(sess, tf.train.latest_checkpoint('./'))
right after sess has been defined. Of course, you need to pass the directory that actually contains your checkpoint instead of ./.
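tf.train.latest_checkpoint(dir) reads the small checkpoint state file in that directory and returns None if it finds none, in which case saver.restore fails. Since import_meta_graph was given '_retrain_checkpoint.meta', another option is to pass the checkpoint prefix (the same path without .meta) to saver.restore directly. The helper below, checkpoint_prefix_from_meta, is a hypothetical name (not a TensorFlow function); it derives that prefix and checks for the .index file that the default V2 checkpoint format writes next to it.

```python
import os


def checkpoint_prefix_from_meta(meta_path):
    """Return the prefix that saver.restore() expects for a given .meta file.

    Hypothetical helper. Assumes the default V2 checkpoint layout, where a
    checkpoint saved as PREFIX consists of PREFIX.meta (graph structure),
    PREFIX.index and PREFIX.data-*-of-* (variable values).
    """
    prefix, ext = os.path.splitext(meta_path)
    if ext != ".meta":
        raise ValueError("expected a .meta file, got %r" % (meta_path,))
    if not os.path.exists(prefix + ".index"):
        raise FileNotFoundError("no %s.index found next to %s" % (prefix, meta_path))
    return prefix


# Usage inside the session block (TensorFlow part shown as a comment only):
#     saver.restore(sess, checkpoint_prefix_from_meta('_retrain_checkpoint.meta'))
```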
So something like this:
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print(sess.run(tensor))
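To check the whole flow end to end without the original retraining output, the sketch below first saves a stand-in 2048x6 variable under the same name, then repeats the answer's steps in a fresh graph. It assumes TensorFlow 2.x accessed through tf.compat.v1 with v2 behavior disabled, so that Session, Saver and ref variables (the float32_ref in the question) behave as in TF 1.x; the temporary directory and random values are illustrative only. tf.compat.v1.report_uninitialized_variables() shows that the variable has no value after import_meta_graph alone and is filled in after restore.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x with the v1 compatibility API. On TF 1.x, plain
# `tf` already provides these names and disable_v2_behavior is unnecessary.
tf1 = tf.compat.v1
tf1.disable_v2_behavior()  # graph mode + ref variables, as in the question

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "_retrain_checkpoint")
name = "final_retrain_ops/weights/final_weights:0"

# 1. Stand-in for the retraining script: save a 2048x6 variable under that name.
with tf1.Graph().as_default():
    with tf1.name_scope("final_retrain_ops"), tf1.name_scope("weights"):
        tf1.Variable(tf1.random_normal([2048, 6]), name="final_weights")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        expected = sess.run(name)  # fetch by tensor name
        saver.save(sess, prefix)   # writes .meta, .index, .data-* and "checkpoint"

# 2. The pattern from the answer: import the structure, then restore the values.
with tf1.Graph().as_default() as graph:
    saver = tf1.train.import_meta_graph(prefix + ".meta")
    tensor = graph.get_tensor_by_name(name)
    uninitialized = tf1.report_uninitialized_variables()
    with tf1.Session() as sess:
        before = sess.run(uninitialized)  # names of variables with no value yet
        saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        after = sess.run(uninitialized)   # empty once the values are restored
        weights = sess.run(tensor)        # NumPy array of shape (2048, 6)

print(before, after)
print(weights.shape, np.array_equal(weights, expected))
nested = weights.tolist()  # 2048 inner lists of 6 floats each
```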