簡體   English   中英

如何在圖構建時獲取張量的尺寸(在TensorFlow中)?

[英]How to get the dimensions of a tensor (in TensorFlow) at graph construction time?

我正在嘗試一個操作異常的操作。

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
  embed = tf.nn.embedding_lookup(embeddings, train_dataset)
  embed = tf.reduce_sum(embed, reduction_indices=0)

所以我需要知道張量embed的尺寸。 我知道可以在運行時完成,但是對於這樣一個簡單的操作而言,這是太多的工作。 有什么更簡單的方法呢?

我看到大多數人對tf.shape(tensor)tensor.get_shape()感到困惑,讓我們澄清一下:

  1. tf.shape

tf.shape用於動態形狀。 如果張量的形狀可變 ,請使用它。 一個例子:輸入是寬度和高度可變的圖像,我們想要將其調整為其大小的一半,然后可以編寫如下內容:
new_height = tf.shape(image)[0] / 2

  1. tensor.get_shape

tensor.get_shape用於固定形狀,這意味着可以在圖形中推導張量的形狀

結論: tf.shape幾乎可以在任何地方使用,但是只能從圖形中推斷出t.get_shape僅用於形狀。

此帖子中的 Tensor.get_shape()

從文檔

c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(c.get_shape())
==> TensorShape([Dimension(2), Dimension(3)])

一個訪問值的函數:

def shape(tensor):
    s = tensor.get_shape()
    return tuple([s[i].value for i in range(0, len(s))])

例:

batch_size, num_feats = shape(logits)

只需在不運行的情況下打印出構造圖(ops)之后的嵌入內容即可:

import tensorflow as tf

...

train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_dataset)
print (embed)

這將顯示嵌入張量的形狀:

Tensor("embedding_lookup:0", shape=(128, 2, 64), dtype=float32)

通常,在訓練模型之前檢查所有張量的形狀是很好的。

讓我們簡單到地獄。 如果您想要一個單一的數字作為維數,例如2, 3, 4, etc.,則只需使用tf.rank() 但是,如果您想要張量的確切形狀,請使用tensor.get_shape()

with tf.Session() as sess:
   arr = tf.random_normal(shape=(10, 32, 32, 128))
   a = tf.random_gamma(shape=(3, 3, 1), alpha=0.1)
   print(sess.run([tf.rank(arr), tf.rank(a)]))
   print(arr.get_shape(), ", ", a.get_shape())     


# for tf.rank()    
[4, 3]

# for tf.get_shape()
Output: (10, 32, 32, 128) , (3, 3, 1)

tf.shape方法是TensorFlow靜態方法。 但是,Tensor類也有方法get_shape。 看到

https://www.tensorflow.org/api_docs/python/tf/Tensor#get_shape

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM