简体   繁体   English

如何在 TensorFlow 数据集上正确使用 tf.function

[英]how to correctly use tf.function with a TensorFlow Dataset

I'm trying to use TF Datasets with a @tf.function to perform some preprocessing on a directory of images.我正在尝试使用带有 @tf.function 的 TF 数据集对图像目录执行一些预处理。 Inside the tf function the image file is read as a RAW string tensor and I'm trying to take a slice from that tensor.tf函数内部,图像文件被读取为 RAW 字符串张量,我试图从该张量中取一个切片。 The slice, the first 13 characters, represent info about .ppm images (header).切片,前 13 个字符,表示有关 .ppm 图像(标题)的信息。 I get an error: ValueError: Shape must be rank 1 but is rank 0 for 'Slice' (op: 'Slice') with input shapes: [], [1], [1] .我收到一个错误: ValueError: Shape must be rank 1 but is rank 0 for 'Slice' (op: 'Slice') with input shapes: [], [1], [1] Initially I was trying to directly slice the .numpy() attribute of the tensor ( filepath input parameter to the tf function), but I think it is semantically wrong to do this inside a tf function.最初我试图直接切片张量的 .numpy() 属性( tf函数的filepath输入参数),但我认为在tf函数中这样做在语义上是错误的。 It also didn't work as the filepath input tensor does not have a numpy() attribute (I don't understand why??).它也不起作用,因为文件filepath输入张量没有 numpy() 属性(我不明白为什么??)。 Outside of the tf function, eg in a jupyter notebook cell, I can iterate over the dataset and get individual items which have a numpy attribute and do a slice and all subsequent processing on it just fine.tf函数之外,例如在 jupyter 笔记本单元格中,我可以遍历数据集并获取具有 numpy 属性的单个项目,并对其进行切片和所有后续处理。 I do realize there may be a gap in my understanding of how TF works (I am using TF 2.0), so I hope someone can clarify what I missed in my readings.我确实意识到我对 TF 工作原理的理解可能存在差距(我使用的是 TF 2.0),所以我希望有人能澄清我在阅读中遗漏的内容。 The purpose of the tf function is convert the ppm images to png, so there is a side effect of this function, but I did not get that far to find out if this is possible to do. tf函数的目的是将 ppm 图像转换为 png,所以这个函数有一个副作用,但我没有深入了解这是否可行。

Here's the code:这是代码:

@tf.function
def ppm_to_png(filepath):
    ppm_bytes = tf.io.read_file(filepath) #.numpy()
    bytes_header = tf.slice(ppm_bytes, [0], [13])
    # bytes_header = ppm_bytes[:13].eval()  # this did not work either with similar error msg
    .
    .
    .
import glob

files = glob.glob(os.path.join(data_dir, '00000/*.ppm'))
dataset = tf.data.Dataset.from_tensor_slices(files)
png_filepaths = dataset.map(ppm_to_png, num_parallel_calls=tf.data.experimental.AUTOTUNE)

To manipulate string values in TF, have a look at the tf.strings namespace .要在 TF 中操作字符串值,请查看tf.strings 命名空间

In this case, you can use tf.strings.substr :在这种情况下,您可以使用tf.strings.substr

@tf.function
def ppm_to_png(filepath):
  ppm_bytes = tf.io.read_file(filepath)
  bytes_header = tf.strings.substr(ppm_bytes, 0, 13)
  tf.print(bytes_header)

tf.slice only operates on the Tensor objects, and doesn't work on their elements. tf.slice只对 Tensor 对象进行操作,对它们的元素无效。 Here, ppm_bytes is a scalar Tensor containing a single element of type tf.string , and whose value is the entire string contents of the file.这里, ppm_bytes是一个标量张量,包含一个tf.string类型的tf.string ,其值是文件的整个字符串内容。 So when you call tf.slice , it only looks at the scalar bit, and is not smart enough to realize that you actually want to take a slice of that element instead.因此,当您调用tf.slice ,它只查看标量位,并且不够聪明,无法意识到您实际上想要获取该元素的切片。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM