簡體   English   中英

嘗試將NN定義為類時初始化tf.variable時出錯

[英]Error in initializing a tf.variable when trying to define the NN as a class

我正在嘗試使用python類定義一個簡單的tensorflow圖,如下所示:

import numpy as np
import tensorflow as tf

class NNclass:

def __init__(self, state_d, action_d, state):
    self.s_dim = state_d
    self.a_dim = action_d
    self.state = state
    self.prediction

@property
def prediction(self):
    a = tf.constant(5, dtype=tf.float32)
    w1 = tf.Variable(np.random.normal(0, 1))
    return tf.add(a, w1)

state = tf.placeholder(tf.float64, shape=[None, 1])
NN_instance = NNclass(1, 2, state)

ses = tf.Session()
ses.run(tf.global_variables_initializer())

nn_input = np.array([[0.5], [0.7]])
print(ses.run(NN_instance.prediction,  feed_dict={state: nn_input}))

當我運行此代碼時,出現以下錯誤:

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1

我的看法是,我有一個NNclass實例,並且遍歷了tf圖,因為def__init__遍歷了預測方法。 但是我不明白為什么運行此命令會產生上述錯誤。 任何幫助請謝謝

創建所有變量后,應調用tf.global_variables_initializer() 在您的示例中, prediction函數定義w1變量,直到ses.run()才初始化。

您可以在__init__函數內部創建變量,如下所示:

class NNclass:
    def __init__(self, state_d, action_d, state):
        self.s_dim = state_d
        self.a_dim = action_d
        self.state = state
        self.a = tf.constant(5, dtype=tf.float32)
        self.w1 = tf.Variable(np.random.normal(0, 1))

    @property
    def prediction(self):
        return tf.add(self.a, self.w1)

在執行操作時將函數的結果傳遞給sess.run()並不是最佳實踐,這會造成混亂。

配置網絡的一種更好的做法是創建一個build_graph()函數,其中定義了所有的tensorflow操作。 然后返回您需要計算的張量(更好的是,將它們存儲在字典中或將它們另存為對象的屬性)。

例:

def build_graph():
  a = tf.constant(5, dtype=tf.float32)
  w1 = tf.Variable(np.random.normal(0, 1))
  a_plus_w = tf.add(a, w1)
  state = tf.placeholder(tf.float64, shape=[None, 1])
  return a_plus_w, state

a_plus_w, state = build_graph()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

nn_input = np.array([[0.5], [0.7]])
print(sess.run(a_plus_w,  feed_dict={state: nn_input}))

您犯的關鍵錯誤是您沒有將張量流的開發兩個階段分開。 您有一個“構建圖”階段,在其中定義了要執行的所有數學運算,然后是“執行”階段,在此階段中,您使用sess.run要求tensorflow為您執行計算。 當您調用sess.run您需要傳遞tensorflow想要計算的張量(已經在圖中定義的tf對象)。 您不應該將tensorflow傳遞給函數來執行。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM