簡體   English   中英

在Tensorflow中獲取未知張量的形狀

[英]Get shape of unknown tensor in Tensorflow

我正在嘗試為OpenAI體育館實現簡單的Q網絡。 我有州供職人員。 狀態用整數表示。 我想要一個熱點。 因此,我這樣做:

input_state = tf.placeholder(tf.int64, shape=(None))
state_oh = tf.one_hot(input_state, env.observation_space.n)

我正在使用(None )除了() )外,因為我想通過批處理來訓練網絡。

我期望state_oh形狀像(None, 16) ,但是我得到了<unknown> 這對我來說是個問題,因為我實現了功能來創建完全連接的層,該層使用tensor.shape確定輸入張量的形狀:

def dense(x, output_size, activation, name=None):  
with tf.name_scope(name, "dense", [x]):       

    w = tf.Variable(tf.random_normal([input_size, output_size]), name="w")
    b = tf.Variable(tf.random_normal([1, output_size]), name="b")
    layer = tf.matmul(x, w) + b
    layer_act = activation(layer)

    return layer_act

這不適用於<unknown>形狀。

如何將一批Integer傳遞給Tensorflow並獲取其第二維(一熱向量的長度)? 我更喜歡不要將輸入的大小顯式傳遞給dense()

我發現,如果我這樣定義占位符:

input_state = tf.placeholder(tf.int64, shape=[None], name="input_state")

我犯了一個非常愚蠢的錯誤。 正確的形狀是[None]而不是(None) ,因為(None)等效於None ,表示“任何形狀”。

使用正確的占位符形狀, state_oh的形狀將是預期的(?, 16)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM