简体   繁体   English

用于动态提取补丁和展平数据集的 TF 管道

[英]TF pipeline to dynamically extract patches and flatten dataset

I was trying to train an autoencoder on image patches.我试图在图像补丁上训练自动编码器。 My training data consists of single-channel images loaded into a numpy array with shape [10000, 256, 512, 1] .我的训练数据由加载到 numpy 数组中的单通道图像组成,形状为[10000, 256, 512, 1] I know how to extract patches from the images but it is rather non-intuitive that the batches select images and thus the number of points per batch depends on how many patches are extracted per image.我知道如何从图像中提取补丁,但是批次 select 图像非常不直观,因此每批次的点数取决于每个图像提取了多少补丁。 If 32 patches are extracted per image, I'd like the dataset to behave as if it were [320000, 256, 512, 1] so that shuffling and batches pull from several images at a time but with the patches extracted on the fly so that this doesn't have to be kept in memory.如果每个图像提取 32 个补丁,我希望数据集的行为就像[320000, 256, 512, 1]一样,以便一次从多个图像中提取混洗和批次,但是动态提取补丁所以这不必保存在 memory 中。

The closest question I've seen around is Load tensorflow images and create patches but, as I've mentioned, it doesn't provide what I want.我见过的最接近的问题是加载 tensorflow 图像并创建补丁,但正如我所提到的,它不能提供我想要的。

PATCH_SIZE = 64

def extract_patches(imgs, patch_size=PATCH_SIZE, stride=PATCH_SIZE//2):
    # extract patches and reshape them into patch images
    n_channels = imgs.shape[-1]
    if len(imgs.shape) < 4:
        imgs = tf.expand_dims(imgs, axis=0)  
    return tf.reshape(tf.image.extract_patches(imgs,
                                               sizes=[1, patch_size, patch_size, n_channels],
                                               strides=[1, stride, stride, n_channels],
                                               rates=[1, 1, 1, 1],
                                               padding='VALID'),
                      (-1, patch_size, patch_size, n_channels))

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            )

creates a dataset that returns batches with shape (batch_size, 105, 64, 64, 1) whereas I want a rank 4 tensor with shape (batch_size, 64, 64, 1) and shuffle to operate on patches (rather than collections of patches for each image).创建一个数据集,该数据集返回形状为(batch_size, 105, 64, 64, 1)的批次,而我想要一个形状为(batch_size, 64, 64, 1)的 rank 4 张量并随机对补丁进行操作(而不是 collections 的补丁)每张图片)。 If I put .map at the end of the pipeline如果我将.map放在管道的末尾

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            )

This does flatten the batches and returns a rank 4 tensor, but in this case each batch has shape (840, 64, 64, 1) .这确实会使批次变平并返回一个等级为 4 的张量,但在这种情况下,每个批次都有形状(840, 64, 64, 1)

What I feel is that what you want to achieve is not possible because you want to shuffle all the patches from all batches of the dataset image and generate it over the fly without saving it in memory.我的感觉是,您想要实现的目标是不可能的,因为您想从所有批次的数据集图像中打乱所有补丁并在运行中生成它而不将其保存在 memory 中。 And because your single image after applying extract_patches is returning 105 patches (because your stride(32) and patch size(64) is not matching) what you can do to achieve rank 4 tensor after applying.batch() is reshaping it as follows,并且因为您在应用 extract_patches 后的单个图像返回 105 个补丁(因为您的步幅(32)和补丁大小(64)不匹配),在 apply.batch() 将其重塑如下之后,您可以做些什么来实现排名 4 的张量,

dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(lambda x:tf.reshape(x, (batch_size * 105,64,64,1)))
            .batch(batch_size)
            )

I'm not sure but you can try this.我不确定,但你可以试试这个。

**Correction: this approach won't work as.batch() will always return the higher rank of dataset element. **更正:此方法不起作用,因为.batch() 将始终返回更高级别的数据集元素。 As tf.data.Dataset.batch() documents mentions正如 tf.data.Dataset.batch() 文件提到的

Combines consecutive elements of this dataset into batches.将此数据集的连续元素组合成批次。

Final Update: you can do this by using.unbatch() function just after.map(), that will reduce your 4 rank dataset element (105,64,64,3) to 3 rank element (64,64,3) and then you can use.shuffle(), .batch() or any other function just like regular dataset.最终更新:您可以在.map() 之后使用.unbatch() function 来执行此操作,这会将您的 4 秩数据集元素 (105,64,64,3) 减少到 3 秩元素 (64,64,3) 和然后你可以使用 .shuffle()、.batch() 或任何其他 function 就像常规数据集一样。

dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
                .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
                .unbatch()
                .shuffle(10*batch_size, reshuffle_each_iteration=True)
                .batch(batch_size)
                )

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM