繁体   English   中英

尝试将NN定义为类时初始化tf.variable时出错

[英]Error in initializing a tf.variable when trying to define the NN as a class

我正在尝试使用python类定义一个简单的tensorflow图,如下所示:

import numpy as np
import tensorflow as tf

class NNclass:

def __init__(self, state_d, action_d, state):
    self.s_dim = state_d
    self.a_dim = action_d
    self.state = state
    self.prediction

@property
def prediction(self):
    a = tf.constant(5, dtype=tf.float32)
    w1 = tf.Variable(np.random.normal(0, 1))
    return tf.add(a, w1)

state = tf.placeholder(tf.float64, shape=[None, 1])
NN_instance = NNclass(1, 2, state)

ses = tf.Session()
ses.run(tf.global_variables_initializer())

nn_input = np.array([[0.5], [0.7]])
print(ses.run(NN_instance.prediction,  feed_dict={state: nn_input}))

当我运行此代码时,出现以下错误:

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1

我的看法是,我有一个NNclass实例,并且遍历了tf图,因为def__init__遍历了预测方法。 但是我不明白为什么运行此命令会产生上述错误。 任何帮助请谢谢

创建所有变量后,应调用tf.global_variables_initializer() 在您的示例中, prediction函数定义w1变量,直到ses.run()才初始化。

您可以在__init__函数内部创建变量,如下所示:

class NNclass:
    def __init__(self, state_d, action_d, state):
        self.s_dim = state_d
        self.a_dim = action_d
        self.state = state
        self.a = tf.constant(5, dtype=tf.float32)
        self.w1 = tf.Variable(np.random.normal(0, 1))

    @property
    def prediction(self):
        return tf.add(self.a, self.w1)

在执行操作时将函数的结果传递给sess.run()并不是最佳实践,这会造成混乱。

配置网络的一种更好的做法是创建一个build_graph()函数,其中定义了所有的tensorflow操作。 然后返回您需要计算的张量(更好的是,将它们存储在字典中或将它们另存为对象的属性)。

例:

def build_graph():
  a = tf.constant(5, dtype=tf.float32)
  w1 = tf.Variable(np.random.normal(0, 1))
  a_plus_w = tf.add(a, w1)
  state = tf.placeholder(tf.float64, shape=[None, 1])
  return a_plus_w, state

a_plus_w, state = build_graph()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

nn_input = np.array([[0.5], [0.7]])
print(sess.run(a_plus_w,  feed_dict={state: nn_input}))

您犯的关键错误是您没有将张量流的开发两个阶段分开。 您有一个“构建图”阶段,在其中定义了要执行的所有数学运算,然后是“执行”阶段,在此阶段中,您使用sess.run要求tensorflow为您执行计算。 当您调用sess.run您需要传递tensorflow想要计算的张量(已经在图中定义的tf对象)。 您不应该将tensorflow传递给函数来执行。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM