
How to initialize a tf.Variable with a tf.constant or a numpy array?

I am trying to initialize a tf.Variable() in a tf.InteractiveSession(). I already have some pre-trained weights stored as individual numpy files. How do I correctly initialize the variable with these numpy values?

I have gone through the following options:

  1. Using tf.assign()
  2. Using sess.run() directly during tf.Variable() creation

The values do not seem to be initialized correctly. Below is some code I have tried. Which approach is the correct one?

def read_numpy(file):
    return np.fromfile(file,dtype='f')

def build_network():
    with tf.get_default_graph().as_default():
        x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
        sess = tf.get_default_session()
        with sess.as_default():
            sess.run(tf.global_variables_initializer())

sess = tf.InteractiveSession()
with sess.as_default():
    build_network()

Is this the correct way to do it? I have printed the session object, and the same session is used throughout.

Edit: Currently it seems that sess.run(tf.global_variables_initializer()) is running a random initialization op instead of assigning my values.

tf.Variable() accepts numpy arrays as initial values:

import tensorflow as tf
import numpy as np

init = np.ones((2, 2))
x = tf.Variable(init) # <-- set initial value to assign to a variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]

So just pass the numpy array as the initial value; there is no need to convert it to a tensor first.
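One thing worth checking in the question's read_numpy() helper: np.fromfile() reads raw bytes and knows nothing about the .npy format, so if foo.npy was written with np.save(), the file header is decoded as floats too and the resulting array has extra garbage values. The sketch below assumes the weights were saved with np.save() (the question does not say how the files were produced) and uses a temporary file and a made-up 2x2 array purely for illustration.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for the pre-trained weights.
weights = np.arange(4, dtype=np.float32).reshape(2, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "foo.npy")
    np.save(path, weights)  # writes a header followed by the raw data

    # np.load parses the header and restores shape and dtype.
    loaded = np.load(path)

    # np.fromfile treats the whole file, header included, as float32 values.
    raw = np.fromfile(path, dtype="f")

print(loaded.shape, loaded.dtype)
print(raw.size > loaded.size)  # the header bytes show up as extra "values"
```

If the files really are headerless binary dumps, np.fromfile() is the right call, but the result is flat and has to be reshaped to the variable's shape before it is passed to tf.Variable().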

Alternatively, you can use tf.Variable.load() to assign values from a numpy array to a variable within a session context:

import tensorflow as tf
import numpy as np

x = tf.Variable(tf.zeros((2, 2)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init = np.ones((2, 2))
    x.load(init)
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]
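The question also lists tf.assign() as an option. A minimal sketch of that route is below: the numpy values are fed through a placeholder so they are not baked into the graph as a large constant. It uses the tf.compat.v1 aliases so the same code should run on TensorFlow 1.14+ and 2.x, which is an assumption about the installed version. The helper name assign_from_numpy is made up for this example and is not part of TensorFlow.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API


def assign_from_numpy(sess, variable, values):
    """Hypothetical helper: copy a numpy array into an existing variable."""
    placeholder = tf1.placeholder(variable.dtype.base_dtype, shape=variable.shape)
    assign_op = tf1.assign(variable, placeholder)
    sess.run(assign_op, feed_dict={placeholder: values})


graph = tf.Graph()
with graph.as_default():
    x = tf1.Variable(tf.zeros((2, 2)), name="var1")

    with tf1.Session(graph=graph) as sess:
        # Run the initializer first, then overwrite with the numpy values.
        sess.run(tf1.global_variables_initializer())
        assign_from_numpy(sess, x, np.ones((2, 2), dtype=np.float32))
        result = sess.run(x)

print(result)
```

The order matters here: running tf.global_variables_initializer() after the assign would reset the variable to its initial value (zeros in this sketch).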
