簡體   English   中英

如何在 TensorFlow 數據集上正確使用 tf.function

[英]how to correctly use tf.function with a TensorFlow Dataset

我正在嘗試使用帶有 @tf.function 的 TF 數據集對圖像目錄執行一些預處理。 tf函數內部,圖像文件被讀取為 RAW 字符串張量,我試圖從該張量中取一個切片。 切片,前 13 個字符,表示有關 .ppm 圖像(標題)的信息。 我收到一個錯誤: ValueError: Shape must be rank 1 but is rank 0 for 'Slice' (op: 'Slice') with input shapes: [], [1], [1] 最初我試圖直接切片張量的 .numpy() 屬性( tf函數的filepath輸入參數),但我認為在tf函數中這樣做在語義上是錯誤的。 它也不起作用,因為文件filepath輸入張量沒有 numpy() 屬性(我不明白為什么??)。 tf函數之外,例如在 jupyter 筆記本單元格中,我可以遍歷數據集並獲取具有 numpy 屬性的單個項目,並對其進行切片和所有后續處理。 我確實意識到我對 TF 工作原理的理解可能存在差距(我使用的是 TF 2.0),所以我希望有人能澄清我在閱讀中遺漏的內容。 tf函數的目的是將 ppm 圖像轉換為 png,所以這個函數有一個副作用,但我沒有深入了解這是否可行。

這是代碼:

@tf.function
def ppm_to_png(filepath):
    ppm_bytes = tf.io.read_file(filepath) #.numpy()
    bytes_header = tf.slice(ppm_bytes, [0], [13])
    # bytes_header = ppm_bytes[:13].eval()  # this did not work either with similar error msg
    .
    .
    .
import glob

files = glob.glob(os.path.join(data_dir, '00000/*.ppm'))
dataset = tf.data.Dataset.from_tensor_slices(files)
png_filepaths = dataset.map(ppm_to_png, num_parallel_calls=tf.data.experimental.AUTOTUNE)

要在 TF 中操作字符串值,請查看tf.strings 命名空間

在這種情況下,您可以使用tf.strings.substr

@tf.function
def ppm_to_png(filepath):
  ppm_bytes = tf.io.read_file(filepath)
  bytes_header = tf.strings.substr(ppm_bytes, 0, 13)
  tf.print(bytes_header)

tf.slice只對 Tensor 對象進行操作,對它們的元素無效。 這里, ppm_bytes是一個標量張量,包含一個tf.string類型的tf.string ,其值是文件的整個字符串內容。 因此,當您調用tf.slice ,它只查看標量位,並且不夠聰明,無法意識到您實際上想要獲取該元素的切片。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM