簡體   English   中英

如何在 Tensorflow 中按張量形狀過濾數據集

[英]How to filter dataset by tensor shape in Tensorflow

我已經從 tfds.load 加載了一個數據集,並且想要丟棄某些干擾正確訓練/對我沒有用的圖像(例如,太小)。

似乎在任何地方都沒有關於這個特定問題的信息,所以我選擇了最適合數據集的 was.filter(predicate) 。 不幸的是,謂詞的輸入具有不確定的形狀(無,無,3),並且正如預期的那樣會引發一個錯誤,即“int”無法與“NoneType”進行比較。

甚至有可能在 tensorflow 中解決這個問題,還是我不應該浪費時間?

偽代碼

ds_train = tfds.load('name')
ds_train = ds_train.map(lambda ds: ds['image'])
ds_train = ds_train.filter(lambda image: image.shape[0] >= 256)

使用tf.data.Dataset編寫代碼時,應使用tf.shape(tensor)而不是tensor.shape ,因為tf.data.Dataset在圖形模式下工作。

引用tf.shape的文檔:

tf.shape 和 Tensor.shape 在 Eager 模式下應該相同。 在 tf.function 或 compat.v1 上下文中,直到執行時才可能知道所有維度。 因此,在為圖形模式定義自定義層和模型時,更喜歡動態 tf.shape(x) 而不是 static x.shape。

ds_train = ds_train.filter(lambda image: tf.shape(image)[0] >= 256)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM