简体   繁体   English

如何将 map function 应用于 tf.Tensor

[英]how to apply map function to the tf.Tensor

dataset = tf.data.Dataset.from_tensor_slices((images,boxes))
function_to_map = lambda x,y: func3(x,y)
fast_benchmark(dataset.map(function_to_map).batch(1).prefetch(tf.data.experimental.AUTOTUNE))

Now I here is the func3现在我这里是func3

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    print('dataset->',dataset)
    for _ in tf.data.Dataset.range(num_epochs):
        for _,__ in dataset:
            print(_,__)
            break
            pass

the ooutput of print is print 的输出是

tf.Tensor([b'/media/jake/mark-4tb3/input/datasets/pascal/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg'], shape=(1,), dtype=string) <tf.RaggedTensor [[[52, 86, 470, 419], [157, 43, 288, 166]]]>

what I want to do in func3()我想在 func3() 中做什么
want to change image directory to the real image and run the batch想要将图像目录更改为真实图像并运行批处理

You need to extract string form the tensor and use the appropriate image reading function.您需要从张量中提取字符串并使用适当的图像读取 function。 Below are the steps to be implemented in the code to achieve this.以下是要在代码中实现的步骤。

  1. You have to decorate the map function with tf.py_function(get_path, [x], [tf.float32]) .你必须用tf.py_function(get_path, [x], [tf.float32])装饰 map function 。 You can find more about tf.py_function here .您可以在此处找到有关 tf.py_function 的更多信息。 In tf.py_function , first argument is the name of map function, second argument is the element to be passed to map function and final argument is the return type. In tf.py_function , first argument is the name of map function, second argument is the element to be passed to map function and final argument is the return type.
  2. You can get your string part by using bytes.decode(file_path.numpy()) in map function.您可以通过在 map function 中使用bytes.decode(file_path.numpy())来获取您的字符串部分。
  3. Use appropriate function to load your image.使用适当的 function 加载您的图像。 We are using load_img .我们正在使用load_img

In the below simple program, we are using tf.data.Dataset.list_files to read path of the image.在下面的简单程序中,我们使用tf.data.Dataset.list_files来读取图像的路径。 Next in the map function we are reading the image using load_img and later doing the tf.image.central_crop function to crop central part of the image.接下来在map function 中,我们使用load_img读取图像,然后执行tf.image.central_crop function 以裁剪图像的中心部分。

Code -代码 -

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)

Output - Output -

在此处输入图像描述

Hope this answers your question.希望这能回答你的问题。 Happy Learning.快乐学习。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM