[英]Get shape of unknown tensor in Tensorflow
我正在嘗試為OpenAI體育館實現簡單的Q網絡。 我有州供職人員。 狀態用整數表示。 我想要一個熱點。 因此,我這樣做:
input_state = tf.placeholder(tf.int64, shape=(None))
state_oh = tf.one_hot(input_state, env.observation_space.n)
我正在使用(None
)除了()
)外,因為我想通過批處理來訓練網絡。
我期望state_oh
形狀像(None, 16)
,但是我得到了<unknown>
。 這對我來說是個問題,因為我實現了功能來創建完全連接的層,該層使用tensor.shape
確定輸入張量的形狀:
def dense(x, output_size, activation, name=None):
with tf.name_scope(name, "dense", [x]):
w = tf.Variable(tf.random_normal([input_size, output_size]), name="w")
b = tf.Variable(tf.random_normal([1, output_size]), name="b")
layer = tf.matmul(x, w) + b
layer_act = activation(layer)
return layer_act
這不適用於<unknown>
形狀。
如何將一批Integer傳遞給Tensorflow並獲取其第二維(一熱向量的長度)? 我更喜歡不要將輸入的大小顯式傳遞給dense()
。
我發現,如果我這樣定義占位符:
input_state = tf.placeholder(tf.int64, shape=[None], name="input_state")
我犯了一個非常愚蠢的錯誤。 正確的形狀是[None]
而不是(None)
,因為(None)
等效於None
,表示“任何形狀”。
使用正確的占位符形狀, state_oh
的形狀將是預期的(?, 16)
。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.