[英]How to extract the shape value of a placeholder Tensor in Tensorflow?
[英]How to filter dataset by tensor shape in Tensorflow
我已經從 tfds.load 加載了一個數據集,並且想要丟棄某些干擾正確訓練/對我沒有用的圖像(例如,太小)。
似乎在任何地方都沒有關於這個特定問題的信息,所以我選擇了最適合數據集的 was.filter(predicate) 。 不幸的是,謂詞的輸入具有不確定的形狀(無,無,3),並且正如預期的那樣會引發一個錯誤,即“int”無法與“NoneType”進行比較。
甚至有可能在 tensorflow 中解決這個問題,還是我不應該浪費時間?
偽代碼
ds_train = tfds.load('name')
ds_train = ds_train.map(lambda ds: ds['image'])
ds_train = ds_train.filter(lambda image: image.shape[0] >= 256)
使用tf.data.Dataset
編寫代碼時,應使用tf.shape(tensor)
而不是tensor.shape
,因為tf.data.Dataset
在圖形模式下工作。
引用tf.shape
的文檔:
tf.shape 和 Tensor.shape 在 Eager 模式下應該相同。 在 tf.function 或 compat.v1 上下文中,直到執行時才可能知道所有維度。 因此,在為圖形模式定義自定義層和模型時,更喜歡動態 tf.shape(x) 而不是 static x.shape。
ds_train = ds_train.filter(lambda image: tf.shape(image)[0] >= 256)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.