![](/img/trans.png)
[英]Error when using tf.get_variable as alternativ for tf.Variable in Tensorflow
[英]Error in initializing a tf.variable when trying to define the NN as a class
我正在尝试使用python类定义一个简单的tensorflow图,如下所示:
import numpy as np
import tensorflow as tf
class NNclass:
def __init__(self, state_d, action_d, state):
self.s_dim = state_d
self.a_dim = action_d
self.state = state
self.prediction
@property
def prediction(self):
a = tf.constant(5, dtype=tf.float32)
w1 = tf.Variable(np.random.normal(0, 1))
return tf.add(a, w1)
state = tf.placeholder(tf.float64, shape=[None, 1])
NN_instance = NNclass(1, 2, state)
ses = tf.Session()
ses.run(tf.global_variables_initializer())
nn_input = np.array([[0.5], [0.7]])
print(ses.run(NN_instance.prediction, feed_dict={state: nn_input}))
当我运行此代码时,出现以下错误:
FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1
我的看法是,我有一个NNclass实例,并且遍历了tf图,因为def__init__遍历了预测方法。 但是我不明白为什么运行此命令会产生上述错误。 任何帮助请谢谢
创建所有变量后,应调用tf.global_variables_initializer()
。 在您的示例中, prediction
函数定义w1
变量,直到ses.run()
才初始化。
您可以在__init__
函数内部创建变量,如下所示:
class NNclass:
def __init__(self, state_d, action_d, state):
self.s_dim = state_d
self.a_dim = action_d
self.state = state
self.a = tf.constant(5, dtype=tf.float32)
self.w1 = tf.Variable(np.random.normal(0, 1))
@property
def prediction(self):
return tf.add(self.a, self.w1)
在执行操作时将函数的结果传递给sess.run()
并不是最佳实践,这会造成混乱。
配置网络的一种更好的做法是创建一个build_graph()
函数,其中定义了所有的tensorflow操作。 然后返回您需要计算的张量(更好的是,将它们存储在字典中或将它们另存为对象的属性)。
例:
def build_graph():
a = tf.constant(5, dtype=tf.float32)
w1 = tf.Variable(np.random.normal(0, 1))
a_plus_w = tf.add(a, w1)
state = tf.placeholder(tf.float64, shape=[None, 1])
return a_plus_w, state
a_plus_w, state = build_graph()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
nn_input = np.array([[0.5], [0.7]])
print(sess.run(a_plus_w, feed_dict={state: nn_input}))
您犯的关键错误是您没有将张量流的开发两个阶段分开。 您有一个“构建图”阶段,在其中定义了要执行的所有数学运算,然后是“执行”阶段,在此阶段中,您使用sess.run
要求tensorflow为您执行计算。 当您调用sess.run
您需要传递tensorflow想要计算的张量(已经在图中定义的tf对象)。 您不应该将tensorflow传递给函数来执行。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.