![](/img/trans.png)
[英]Error when using tf.get_variable as alternativ for tf.Variable in Tensorflow
[英]Error in initializing a tf.variable when trying to define the NN as a class
我正在嘗試使用python類定義一個簡單的tensorflow圖,如下所示:
import numpy as np
import tensorflow as tf
class NNclass:
def __init__(self, state_d, action_d, state):
self.s_dim = state_d
self.a_dim = action_d
self.state = state
self.prediction
@property
def prediction(self):
a = tf.constant(5, dtype=tf.float32)
w1 = tf.Variable(np.random.normal(0, 1))
return tf.add(a, w1)
state = tf.placeholder(tf.float64, shape=[None, 1])
NN_instance = NNclass(1, 2, state)
ses = tf.Session()
ses.run(tf.global_variables_initializer())
nn_input = np.array([[0.5], [0.7]])
print(ses.run(NN_instance.prediction, feed_dict={state: nn_input}))
當我運行此代碼時,出現以下錯誤:
FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1
我的看法是,我有一個NNclass實例,並且遍歷了tf圖,因為def__init__遍歷了預測方法。 但是我不明白為什么運行此命令會產生上述錯誤。 任何幫助請謝謝
創建所有變量后,應調用tf.global_variables_initializer()
。 在您的示例中, prediction
函數定義w1
變量,直到ses.run()
才初始化。
您可以在__init__
函數內部創建變量,如下所示:
class NNclass:
def __init__(self, state_d, action_d, state):
self.s_dim = state_d
self.a_dim = action_d
self.state = state
self.a = tf.constant(5, dtype=tf.float32)
self.w1 = tf.Variable(np.random.normal(0, 1))
@property
def prediction(self):
return tf.add(self.a, self.w1)
在執行操作時將函數的結果傳遞給sess.run()
並不是最佳實踐,這會造成混亂。
配置網絡的一種更好的做法是創建一個build_graph()
函數,其中定義了所有的tensorflow操作。 然后返回您需要計算的張量(更好的是,將它們存儲在字典中或將它們另存為對象的屬性)。
例:
def build_graph():
a = tf.constant(5, dtype=tf.float32)
w1 = tf.Variable(np.random.normal(0, 1))
a_plus_w = tf.add(a, w1)
state = tf.placeholder(tf.float64, shape=[None, 1])
return a_plus_w, state
a_plus_w, state = build_graph()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
nn_input = np.array([[0.5], [0.7]])
print(sess.run(a_plus_w, feed_dict={state: nn_input}))
您犯的關鍵錯誤是您沒有將張量流的開發兩個階段分開。 您有一個“構建圖”階段,在其中定義了要執行的所有數學運算,然后是“執行”階段,在此階段中,您使用sess.run
要求tensorflow為您執行計算。 當您調用sess.run
您需要傳遞tensorflow想要計算的張量(已經在圖中定義的tf對象)。 您不應該將tensorflow傳遞給函數來執行。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.