
Reconstructing an image after using extract_image_patches

I have an autoencoder that takes an image as input and produces a new image as output.

The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network.

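Once I have the output, which is also a batch of patches of size 1024x32x32x3, I want to be able to reconstruct a 1024x1024x3 image. I thought a simple reshape would take care of this, but here's what happened.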

First, the image as read by TensorFlow: [Image: input image]

I patched the image with the following code:

patch_size = [1, 32, 32, 1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [1024, 32, 32, 3])
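A quick way to see what this reshape receives: `extract_patches_np` below is a hypothetical NumPy stand-in for `tf.extract_image_patches` with ksizes == strides and 'VALID' padding. It assumes that each patch is flattened into the depth axis in (row, column, channel) order, which is how TensorFlow lays patches out. On an 8x8x3 toy image, reshaping to a batch of patches yields the expected blocks, so the patch extraction step itself is fine:

```python
import numpy as np

def extract_patches_np(images, p):
    """Hypothetical NumPy stand-in for tf.extract_image_patches with ksizes == strides == [1, p, p, 1], 'VALID'."""
    n, h, w, c = images.shape
    x = images.reshape(n, h // p, p, w // p, p, c)   # split height and width into (block, offset)
    x = x.transpose(0, 1, 3, 2, 4, 5)                # -> (n, block_row, block_col, dy, dx, c)
    return x.reshape(n, h // p, w // p, p * p * c)   # flatten each patch into the depth axis

image = np.arange(8 * 8 * 3).reshape(1, 8, 8, 3)
patches = extract_patches_np(image, 4)               # shape (1, 2, 2, 48)
batch = patches.reshape(-1, 4, 4, 3)                 # same kind of reshape as in the question
# Patch 1 is the top-right 4x4 block, so the batch of patches is laid out correctly
assert np.array_equal(batch[1], image[0, :4, 4:8])
print(patches.shape, batch.shape)
```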

Here are a couple of patches from this image:

[Images: patched input #168, patched input #169]

But when I reshape this patch data back into an image, things go pear-shaped:

reconstructed = tf.reshape(patches, [1, 1024, 1024, 3])
converted = tf.image.convert_image_dtype(reconstructed, tf.uint8)
encoded = tf.image.encode_png(converted)

[Image: reconstructed output]

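In this example, no processing is done between patching and reconstruction. I made a version of the code that you can use to test this behavior. To use it, run the following: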

echo "/path/to/test-image.png" > inputs.txt
mkdir images
python3 image_test.py inputs.txt images

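For each of the 1024 patches in each input image, the code produces an input image, a patch image and an output image, so if you only care about saving all the patches, comment out the lines that create the input and output images.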

Can someone please explain what is happening? :(
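The problem can be reproduced in NumPy on a 4x4 toy image using hypothetical helpers. The patch batch stores each patch's pixels contiguously, so a plain reshape writes every flattened 2x2 patch into a single image row. Undoing the transpose that the extraction performed restores the image:

```python
import numpy as np

def extract_patches_np(images, p):
    # Hypothetical NumPy stand-in for non-overlapping extract_image_patches (ksizes == strides == p)
    n, h, w, c = images.shape
    x = images.reshape(n, h // p, p, w // p, p, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, p, p, c)                        # batch of patches, as in the question

def patches_to_image_np(patches, h, w):
    # Undo the transpose: (block_row, block_col, dy, dx, c) -> (block_row, dy, block_col, dx, c)
    _, p, _, c = patches.shape
    x = patches.reshape(h // p, w // p, p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(1, h, w, c)

image = np.arange(4 * 4).reshape(1, 4, 4, 1)
patches = extract_patches_np(image, 2)
naive = patches.reshape(1, 4, 4, 1)                      # the plain reshape from the question
print(naive[0, :, :, 0])                                 # each row holds one flattened 2x2 patch
assert not np.array_equal(naive, image)                  # pixels end up in the wrong places
assert np.array_equal(patches_to_image_np(patches, 4, 4), image)
```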

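Since I also struggled with this, I'm posting a solution that may be useful to others. The trick is to realize that the inverse of tf.extract_image_patches can be built from its gradient, as suggested here. The gradient scatters every patch value back to the pixel it came from and sums the copies, so for overlapping patches the result has to be divided by the number of patches covering each pixel. Since the gradient of this op is implemented in TensorFlow, building the reconstruction function is easy: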

import tensorflow as tf
from keras import backend as K
import numpy as np

def extract_patches(x):
    return tf.extract_image_patches(
        x,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad, to "average" together the overlapping patches
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad

# Generate 10 fake images, last dimension can be different than 3
images = np.random.random((10, 28, 28, 3)).astype(np.float32)
# Extract patches
patches = extract_patches(images)
# Reconstruct image
# Notice that original images are only passed to infer the right shape
images_reconstructed = extract_patches_inverse(images, patches) 

# Compare with original (evaluating tf.Tensor into a numpy array)
# Here using Keras session
images_r = images_reconstructed.eval(session=K.get_session())

print (np.sum(np.square(images - images_r))) 
# 2.3820458e-11
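The gradient trick can be checked without TensorFlow. `extract` and `extract_adjoint` below are hypothetical NumPy helpers for 'VALID' 3x3 patches with stride 1, the same setting as above. The adjoint (transpose) of a linear op is what `tf.gradients` computes when given `grad_ys`, and dividing by the adjoint of an all-ones tensor plays the role of `grad`:

```python
import numpy as np

def extract(x, k):
    """All k x k patches with stride 1 ('VALID'), each flattened in (dy, dx, c) order."""
    n, h, w, c = x.shape
    oh, ow = h - k + 1, w - k + 1
    out = np.empty((n, oh, ow, k, k, c), dtype=x.dtype)
    for dy in range(k):
        for dx in range(k):
            out[:, :, :, dy, dx, :] = x[:, dy:dy + oh, dx:dx + ow, :]
    return out.reshape(n, oh, ow, k * k * c)

def extract_adjoint(y, shape, k):
    """Transpose of `extract`: scatter-add every patch value back to its source pixel."""
    n, h, w, c = shape
    oh, ow = h - k + 1, w - k + 1
    y = y.reshape(n, oh, ow, k, k, c)
    x = np.zeros(shape, dtype=y.dtype)
    for dy in range(k):
        for dx in range(k):
            x[:, dy:dy + oh, dx:dx + ow, :] += y[:, :, :, dy, dx, :]
    return x

rng = np.random.default_rng(0)
images = rng.random((2, 8, 8, 3))
patches = extract(images, 3)
# Adjoint identity <E x, y> == <x, E^T y>
y = rng.random(patches.shape)
assert np.isclose(np.sum(patches * y), np.sum(images * extract_adjoint(y, images.shape, 3)))
# Scatter-adding sums overlapping copies, so divide by how many patches cover each pixel
counts = extract_adjoint(np.ones_like(patches), images.shape, 3)
reconstructed = extract_adjoint(patches, images.shape, 3) / counts
assert np.allclose(reconstructed, images)
```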

Use Update #2 - a small example for your task (TF 1.0):

Consider an image of size (4,4,1) that is converted into patches of size (4,2,2,1) and then reconstructed back into the image.

import tensorflow as tf
image = tf.constant([[[1],   [2],  [3],  [4]],
                 [[5],   [6],  [7],  [8]],
                 [[9],  [10], [11],  [12]],
                [[13], [14], [15],  [16]]])

patch_size = [1,2,2,1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [4, 2, 2, 1])
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed,2)
rec_new = tf.reshape(rec_new,[4,4,1])

sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I)
print(I.shape)
print(P.shape)
print(R_n)
print(R_n.shape)

Output:

[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)

Update - 3 channels (debugging...)

This only works when p = sqrt(h), i.e. when the number of patches along each side (h/p) equals the patch size p.

import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 32

image = tf.random_normal([h,h,c])
patch_size = [1,p,p,1]
patches = tf.extract_image_patches([image],
   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [h, p, p, c])
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed,p)
rec_new = tf.reshape(rec_new,[h,h,c])

sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)

Output:

(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0
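The p = sqrt(h) restriction can be checked in NumPy. `space_to_depth_np` is a hypothetical stand-in for `tf.space_to_depth` in NHWC layout, and `roundtrip` replays the reshape, space_to_depth, reshape sequence from the code above. The index bookkeeping only lines up when the number of patches per side equals the patch size:

```python
import numpy as np

def space_to_depth_np(x, b):
    # NumPy stand-in for tf.space_to_depth (NHWC): depth ordered (block_row, block_col, channel)
    n, h, w, c = x.shape
    x = x.reshape(n, h // b, b, w // b, b, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h // b, w // b, b * b * c)

def roundtrip(h, p, c=1):
    image = np.arange(h * h * c).reshape(1, h, h, c)
    patches = space_to_depth_np(image, p)            # same layout as extract_image_patches with ksizes == strides == p
    reconstructed = patches.reshape(1, h, h, c)      # the plain reshape used above
    rec_new = space_to_depth_np(reconstructed, p).reshape(h, h, c)
    return np.array_equal(rec_new, image[0])

print(roundtrip(16, 4), roundtrip(16, 2), roundtrip(64, 8, c=3))
```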

Update 2

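Reconstructing from the output of extract_image_patches seems difficult. It seems easier to use other functions to extract the patches and then reverse the process to reconstruct the image.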

import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 128


image = tf.random_normal([1,h,h,c])

# Image to Patches Conversion
pad = [[0,0],[0,0]]
patches = tf.space_to_batch_nd(image,[p,p],pad)
patches = tf.split(patches,p*p,0)
patches = tf.stack(patches,3)
# Use integer division: in Python 3, h/p is a float and cannot be used as a shape
patches = tf.reshape(patches,[(h//p)**2,p,p,c])

# Do processing on patches
# Using patches here to reconstruct
patches_proc = tf.reshape(patches,[1,h//p,h//p,p*p,c])
patches_proc = tf.split(patches_proc,p*p,3)
patches_proc = tf.stack(patches_proc,axis=0)
patches_proc = tf.reshape(patches_proc,[p*p,h//p,h//p,c])

reconstructed = tf.batch_to_space_nd(patches_proc,[p, p],pad)

sess = tf.Session()
I,P,R_n = sess.run([image,patches,reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)

Output:

(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0
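The reason Update 2 produces contiguous patches can be traced in NumPy. `space_to_batch_np` is a hypothetical stand-in for `tf.space_to_batch_nd` with zero padding and batch size 1, following the block ordering described in the TensorFlow docs (batch index by * p + bx for the pixel offset (by, bx) inside each block). Splitting, stacking along axis 3 and reshaping then groups each p x p block together:

```python
import numpy as np

def space_to_batch_np(x, p):
    # NumPy stand-in for tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]]) with batch size 1:
    # output[by * p + bx, i, j] == x[0, i * p + by, j * p + bx]
    _, h, w, c = x.shape
    x = x.reshape(h // p, p, w // p, p, c).transpose(1, 3, 0, 2, 4)
    return x.reshape(p * p, h // p, w // p, c)

h, p, c = 16, 4, 3
image = np.random.default_rng(0).normal(size=(1, h, h, c))
s2b = space_to_batch_np(image, p)                               # (p*p, h/p, h/p, c)
stacked = np.stack(np.split(s2b, p * p, axis=0), axis=3)        # (1, h/p, h/p, p*p, c), like tf.split + tf.stack
patches = stacked.reshape((h // p) ** 2, p, p, c)               # integer division h // p
assert np.array_equal(patches[1], image[0, :p, p:2 * p])        # patch 1 is the second block in the top row
```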

You can find other useful tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops

tf.extract_image_patches is hard to use because it does a lot of things behind the scenes.

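If you only need non-overlapping patches, it's much easier to write the extraction ourselves. You can then reconstruct the full image by reversing every operation in image_to_patches.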

Code example (plots the original image and the patches):

import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt


def image_to_patches(image, patch_height, patch_width):
    # resize image so that its dimensions are divisible by patch_height and patch_width
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)
    height = tf.cast(tf.ceil(image_height / patch_height) * patch_height, dtype=tf.int32)
    width = tf.cast(tf.ceil(image_width / patch_width) * patch_width, dtype=tf.int32)

    num_rows = height // patch_height
    num_cols = width // patch_width
    # make zero-padding
    image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))

    # get slices along the 0-th axis
    image = tf.reshape(image, [num_rows, patch_height, width, -1])
    # h/patch_h, w, patch_h, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # get slices along the 1-st axis
    # h/patch_h, w/patch_w, patch_w,patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # num_patches, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1])
    # num_patches, patch_h, patch_w, c
    return tf.transpose(image, [0, 2, 1, 3])


image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
print('Original image shape:', image.shape)
tile_size = 200
image = tf.constant(image)
tiles = image_to_patches(image, tile_size, tile_size)

sess = tf.Session()
I, tiles = sess.run([image, tiles])
print(I.shape)
print(tiles.shape)


plt.figure(figsize=(1 * (4 + 1), 5))
plt.subplot(5, 1, 1)
plt.imshow(I)
plt.title('original')
plt.axis('off')
for i, tile in enumerate(tiles):
    plt.subplot(5, 5, 5 + 1 + i)
    plt.imshow(tile)
    plt.title(str(i))
    plt.axis('off')
plt.show()
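Reversing image_to_patches could be sketched like this in NumPy. The helper names are hypothetical, and the centered zero-padding is an assumption meant to mirror tf.image.resize_image_with_crop_or_pad, which centers the image when it pads. The inverse regroups the tiles into a grid, undoes the transpose and crops the padding away:

```python
import numpy as np

def image_to_patches_np(image, ph, pw):
    # NumPy sketch: center zero-pad (h, w, c) up to multiples of (ph, pw), then cut into patches
    h, w, c = image.shape
    H, W = -(-h // ph) * ph, -(-w // pw) * pw        # round up to the next multiple
    top, left = (H - h) // 2, (W - w) // 2
    padded = np.zeros((H, W, c), dtype=image.dtype)
    padded[top:top + h, left:left + w] = image
    x = padded.reshape(H // ph, ph, W // pw, pw, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, ph, pw, c)

def patches_to_image_np(patches, h, w):
    # Reverse every step: regroup patches into a grid, undo the transpose, then crop the padding
    _, ph, pw, c = patches.shape
    H, W = -(-h // ph) * ph, -(-w // pw) * pw
    x = patches.reshape(H // ph, W // pw, ph, pw, c).transpose(0, 2, 1, 3, 4)
    top, left = (H - h) // 2, (W - w) // 2
    return x.reshape(H, W, c)[top:top + h, left:left + w]

image = np.random.default_rng(1).integers(0, 256, size=(450, 600, 3), dtype=np.uint8)
tiles = image_to_patches_np(image, 200, 200)
print(tiles.shape)                                    # a 3 x 3 grid of 200 x 200 tiles
assert np.array_equal(patches_to_image_np(tiles, 450, 600), image)
```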
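
Another approach, using tf.image.extract_patches (TF 2.x) with non-overlapping square patches:

import tensorflow as tf

# image: a 4-D tensor of shape (1, n_row, n_col, n_channel); patch_size: an int dividing n_row and n_col
_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size ** 2)  # assume square patch

patches = tf.image.extract_patches(image, sizes=[1, patch_size, patch_size, 1], strides=[1, patch_size, patch_size, 1], rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])

# One group per row of patches: n_row // patch_size groups of n_col // patch_size patches each
rows = tf.split(patches, n_row // patch_size, axis=0)
# Lay the patches of each group side by side along the width axis
rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]

reconstructed = tf.concat(rows, axis=0)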

I don't know whether this is an efficient implementation, but it works!
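The grouping step can be checked in NumPy on a non-square image. Each group that gets concatenated along the width must hold one row of patches (n_col // patch_size of them), which means splitting into n_row // patch_size groups. `extract_patches_np` and `reconstruct` are hypothetical helpers that mirror the TensorFlow calls:

```python
import numpy as np

def extract_patches_np(image, p):
    # Hypothetical NumPy stand-in for tf.image.extract_patches(sizes=strides=[1, p, p, 1], 'VALID') + reshape
    _, h, w, c = image.shape
    x = image.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p, p, c)

def reconstruct(patches, n_row, n_col):
    p = patches.shape[1]
    # One group per patch row: n_row // p groups, each holding n_col // p patches
    rows = np.split(patches, n_row // p, axis=0)
    rows = [np.concatenate(list(r), axis=1) for r in rows]   # place the patches side by side
    return np.concatenate(rows, axis=0)

image = np.arange(1 * 6 * 10 * 2).reshape(1, 6, 10, 2)          # non-square: 6 x 10, 2 channels
patches = extract_patches_np(image, 2)
assert np.array_equal(reconstruct(patches, 6, 10), image[0])
```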

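To address the original question specifically, i.e. "reconstructing an image after using extract_image_patches", I suggest using tf.scatter_nd() and building a stratified image. This works even when the extracted patches overlap or when the image is undersampled. Here is my proposed solution: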

import cv2
import numpy as np
import tensorflow as tf

# Function to extract patches using 'extract_image_patches'
def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):

    with tf.variable_scope('im2_patches'):
        patches = tf.image.extract_image_patches(
            images=raw_input,
            ksizes=[1, _patch_size[0], _patch_size[1], 1],
            strides=[1, _stride, _stride, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )

        h = tf.shape(patches)[1]
        w = tf.shape(patches)[2]
        patches = tf.reshape(patches, (patches.shape[0], -1, _patch_size[0], _patch_size[1], 3))
    return patches, (h, w)


# Function to reconstruct image
def patches_to_img(update, _block_shape, _stride=100):
    with tf.variable_scope('patches2im'):
        _h = _block_shape[0]
        _w = _block_shape[1]

        bs = tf.shape(update)[0]  # batch size
        n_patches = tf.shape(update)[1]  # number of patches (not named `np`, which would shadow numpy)
        ps_h = tf.shape(update)[2]  # patch height
        ps_w = tf.shape(update)[3]  # patch width
        col_ch = tf.shape(update)[4]  # Colour channel count

        wout = (_w - 1) * _stride + ps_w  # Recalculate output shape of "extract_image_patches" including padded pixels
        hout = (_h - 1) * _stride + ps_h  # Recalculate output shape of "extract_image_patches" including padded pixels

        x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
        x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
        y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
        xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                     tf.range(0, (hout - ps_h) + 1, _stride))

        bb = tf.zeros((1, n_patches, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))  #  batch indices
        yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))  # y indices
        xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))  # x indices
        cc = tf.zeros((bs, n_patches, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))  # color indices
        dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(n_patches), (1, -1, 1, 1, 1, 1))  # shift indices

        idx = tf.concat([bb, yy, xx, cc, dd], -1)

        stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, n_patches))
        stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))

        stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, n_patches))
        stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))

        with tf.variable_scope("consolidate"):
            sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
            stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
            reconstructed_img = tf.divide(sum_stratified_img, stratified_img_count)

        return reconstructed_img, stratified_img



if __name__ == "__main__":

    # load initial image
    image_org = cv2.imread('orig_img.jpg')
    # Add batch dimension
    image = np.expand_dims(image_org, axis=0)

    # set parameters
    patch_size = (228, 228)
    stride = 200

    input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

    # Extract patches using "extract_image_patches()"
    extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
    # block_shape is the number of patches extracted in the x and in the y dimension
    # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)

    reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # Reconstruct Image


    with tf.Session() as sess:
        ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img], feed_dict={input_img: image})
        # print(bs)
    si = si.astype(np.int32)

    # Show reconstructed image
    cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
    cv2.waitKey(0)

    # Show stratified images
    for i in range(si.shape[1]):

        im_1 = si[0, i, :, :, :]
        cv2.imshow('sd', im_1.astype(np.float32)/255)
        cv2.waitKey(0)  # needed for the window to actually display each stratified image

The solution above should work for batches of images with any number of color channels.
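The stratified idea behind patches_to_img can be sketched in NumPy using explicit patch positions in place of tf.scatter_nd: one layer per patch, then a per-pixel sum divided by a per-pixel count. `patches_to_img_np` is a hypothetical helper. Unlike the TensorFlow code above, it leaves pixels that no patch covers (undersampling) at 0 rather than computing 0/0:

```python
import numpy as np

def patches_to_img_np(patches, positions, out_shape):
    # Stratified sketch: place each patch in its own layer, then consolidate with sum / count
    h, w, c = out_shape
    n, ph, pw, _ = patches.shape
    layers = np.zeros((n, h, w, c))
    counts = np.zeros((n, h, w, c))
    for k, (y, x) in enumerate(positions):
        layers[k, y:y + ph, x:x + pw] = patches[k]
        counts[k, y:y + ph, x:x + pw] = 1
    total, cover = layers.sum(axis=0), counts.sum(axis=0)
    # Uncovered pixels would give 0/0; keep them at 0 instead of NaN
    return np.divide(total, cover, out=np.zeros_like(total), where=cover > 0), layers

rng = np.random.default_rng(2)
image = rng.random((10, 10, 3))
ps, stride = 4, 3                        # overlapping: stride < patch size
positions = [(y, x) for y in range(0, 10 - ps + 1, stride) for x in range(0, 10 - ps + 1, stride)]
patches = np.stack([image[y:y + ps, x:x + ps] for y, x in positions])
rec, layers = patches_to_img_np(patches, positions, image.shape)
assert np.allclose(rec, image)

# Undersampled: patches at rows/cols 0 and 5 leave rows/cols 4 and 9 uncovered
gappy = [(y, x) for y in (0, 5) for x in (0, 5)]
gappy_patches = np.stack([image[y:y + ps, x:x + ps] for y, x in gappy])
rec_gappy, _ = patches_to_img_np(gappy_patches, gappy, image.shape)
assert np.all(rec_gappy[4] == 0)
```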

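This code works for your specific case, and more generally whenever the image is square, the kernel is square, and the image size is divisible by the kernel size.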

I haven't tested it for other cases.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


size = 1024
k_size = 32
axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))

# Define a placeholder for image (or load it directly if you prefer) 
img = tf.placeholder(tf.int32, shape=(1, size, size, 3))

# Extract patches
patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1], 
                                         strides=[1, k_size, k_size, 1], 
                                         rates=[1, 1, 1, 1], padding='VALID')

# Reconstruct the image back from the patches
# First separate out the channel dimension
reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3)) 
# Transpose the axes (I got this axes tuple for transpose via experimentation)
reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
# Reshape back
reconstruct = tf.reshape(reconstruct, (size, size, 3))

im_arr = np.random.randint(0, 256, size=(size, size, 3))  # stand-in: replace with your own image of shape (size, size, 3)

# Run the operations
with tf.Session() as sess:
    ps, r = sess.run([patches, reconstruct], feed_dict={img:[im_arr]})

# Plot the reconstructed image to verify
plt.imshow(r)
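The transpose tuple (0, 1, 3, 2, 4, 5) just swaps the patch-column axis with the row-within-patch axis. So the same permutation should also handle rectangular images and rectangular kernels, as long as each image dimension is divisible by the matching kernel dimension. Here is a NumPy check under that assumption (the helper names are hypothetical):

```python
import numpy as np

def extract_patches_np(img, kh, kw):
    # Hypothetical NumPy stand-in for extract_image_patches(ksizes=strides=[1, kh, kw, 1], 'VALID')
    n, h, w, c = img.shape
    x = img.reshape(n, h // kh, kh, w // kw, kw, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h // kh, w // kw, kh * kw * c)

def reconstruct_np(patches, kh, kw, c):
    n, rows, cols, _ = patches.shape
    x = patches.reshape(n, rows, cols, kh, kw, c)   # separate kernel rows, kernel cols and channels
    x = x.transpose(0, 1, 3, 2, 4, 5)               # (n, rows, kh, cols, kw, c)
    return x.reshape(n, rows * kh, cols * kw, c)

img = np.random.default_rng(3).integers(0, 256, size=(2, 48, 80, 3))
patches = extract_patches_np(img, 16, 8)            # non-square image and non-square kernel
assert np.array_equal(reconstruct_np(patches, 16, 8, 3), img)
```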

If you don't need overlapping patches, TF 2.0 users can use space_to_depth and depth_to_space.
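In NumPy terms, space_to_depth with block size p produces the same layout as non-overlapping extract_patches with size = stride = p, and depth_to_space undoes it. The functions below are hypothetical NumPy stand-ins written from the NHWC semantics of tf.nn.space_to_depth and tf.nn.depth_to_space:

```python
import numpy as np

def space_to_depth_np(x, b):
    # Each b x b block becomes one depth vector ordered (row in block, column in block, channel),
    # the same order extract_patches uses for a patch
    n, h, w, c = x.shape
    x = x.reshape(n, h // b, b, w // b, b, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h // b, w // b, b * b * c)

def depth_to_space_np(x, b):
    # Exact inverse of space_to_depth_np
    n, hb, wb, d = x.shape
    c = d // (b * b)
    x = x.reshape(n, hb, wb, b, b, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, hb * b, wb * b, c)

x = np.arange(2 * 8 * 12 * 3).reshape(2, 8, 12, 3)
patches = space_to_depth_np(x, 4)        # (2, 2, 3, 48): one flattened 4x4x3 patch per position
assert np.array_equal(depth_to_space_np(patches, 4), x)
```

If this matches TensorFlow's behavior, then in TF 2.x applying tf.nn.depth_to_space(patches, p) directly to the (batch, rows, cols, p*p*c) output of tf.image.extract_patches with sizes = strides = [1, p, p, 1] and padding='VALID' should return the original image. That is worth confirming on your own data.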

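I may be a bit late, but since I got this working with TF 2.3, it may be useful to others. The following code works for non-overlapping patches, single-channel or multi-channel: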

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

class PatchesToImage(layers.Layer):
    def __init__(self, imgh, imgw, imgc, patsz, is_squeeze=True, **kwargs):
        super(PatchesToImage, self).__init__(**kwargs)
        self.H = (imgh // patsz) * patsz
        self.W = (imgw // patsz) * patsz
        self.C = imgc
        self.P = patsz
        self.is_squeeze = is_squeeze
        
    def call(self, inputs):
        bs = tf.shape(inputs)[0]
        rows, cols = self.H // self.P, self.W // self.P
        # (bs, rows, cols, P*P*C) -> (bs, rows, cols, P*P, C)
        patches = tf.reshape(inputs, [bs, rows, cols, -1, self.C])
        pats_by_clist = tf.unstack(patches, axis=-1)
        # Tile each channel separately with depth_to_space. A Python loop over the static channel
        # count avoids indexing a Python list with a tensor inside tf.map_fn (fails in graph mode).
        channels = [tf.nn.depth_to_space(pats, self.P) for pats in pats_by_clist]
        img = tf.concat(channels, axis=-1)  # (bs, H, W, C)
        if self.is_squeeze and self.C == 1:
            img = tf.squeeze(img, axis=-1)
        return img

Implementation using the im2col and col2im approach:

import numpy as np
import keras
import tensorflow as tf
import matplotlib.pyplot as plt

class ImPatch():
    def __init__(self):
        pass

    def save_image(self, img, N=None):
        plt.imshow(img)
        plt.savefig(str(N))
        plt.clf()

    def get_indices(self, X_shape, HF, WF, stride, pad):
        # get input size
        m, n_C, n_H, n_W = X_shape

        # get output size
        out_h = int((n_H + 2 * pad - HF) / stride) + 1
        out_w = int((n_W + 2 * pad - WF) / stride) + 1

        # ----Compute matrix of index i----

        # Level 1 vector.
        level1 = np.repeat(np.arange(HF), WF)
        # Duplicate for the other channels.
        level1 = np.tile(level1, n_C)
        # Create a vector with an increase by 1 at each level.
        everyLevels = stride * np.repeat(np.arange(out_h), out_w)
        # Create matrix of index i at every levels for each channel.
        i = level1.reshape(-1, 1) + everyLevels.reshape(1, -1)

        # ----Compute matrix of index j----

        # Slide 1 vector.
        slide1 = np.tile(np.arange(WF), HF)
        # Duplicate for the other channels.
        slide1 = np.tile(slide1, n_C)
        # Create a vector with an increase by 1 at each slide.
        everySlides = stride * np.tile(np.arange(out_w), out_h)
        # Create matrix of index j at every slides for each channel.
        j = slide1.reshape(-1, 1) + everySlides.reshape(1, -1)

        # ----Compute matrix of index d----

        # This is to mark delimitation for each channel
        # during multi-dimensional arrays indexing.
        d = np.repeat(np.arange(n_C), HF * WF).reshape(-1, 1)

        return i, j, d

    def im2col(self, X, HF, WF, stride, pad):
        # Padding
        X_padded = np.pad(X, ((0,0), (0,0), (pad, pad), (pad, pad)), mode='constant')
        i, j, d = self.get_indices(X.shape, HF, WF, stride, pad)
        # Multi-dimensional arrays indexing.
        cols = X_padded[:, d, i, j]
        cols = np.concatenate(cols, axis=-1)
        return cols

    def col2im(self, col, X_shape, HF, WF, stride, pad):
        # Get input size
        N, D, H, W = X_shape
        # Add padding if needed.
        H_padded, W_padded = H + 2 * pad, W + 2 * pad
        X_padded = np.zeros((N, D, H_padded, W_padded))

        # Index matrices, necessary to transform our input image into a matrix.
        i, j, d = self.get_indices(X_shape, HF, WF, stride, pad)
        # Retrieve batch dimension by splitting dX_col N times: (X, Y) => (N, X, Y)
        dX_col_reshaped = np.array(np.hsplit(col, N))
        # Reshape our matrix back to image.
        # slice(None) is used to produce the [::] effect which means "for every elements".
        # np.add.at sums overlapping patches (stride < HF or WF) instead of averaging them.
        np.add.at(X_padded, (slice(None), d, i, j), dX_col_reshaped)
        # Remove padding from new image if needed (the padding is on the last two axes).
        if pad == 0:
            return X_padded
        return X_padded[:, :, pad:-pad, pad:-pad]

    def get_patches(self, x, HF, WF, stride, verbose=False):
        x_patches = tf.image.extract_patches(x, sizes=[1, HF, WF, 1], strides=[1, stride, stride, 1], rates=[1, 1, 1, 1], padding='VALID')
        if verbose:
            print(x_patches.shape, 'x_patches shape')

        return x_patches

    def get_img(self, x_patches, x_shape, HF, WF, stride, verbose=False):
        x_patches = np.asarray(x_patches)
        N, out_h, out_w, _ = x_patches.shape
        C = x_shape[3]
        # extract_patches orders each patch's depth as (row, col, channel), while get_indices
        # expects (channel, row, col); reorder so multi-channel images are rebuilt correctly.
        x_patches = x_patches.reshape(N, out_h, out_w, HF, WF, C)
        x_patches_T = x_patches.transpose(0, 5, 3, 4, 1, 2).reshape(N, C * HF * WF, out_h, out_w)
        x_col = self.im2col(X=x_patches_T, HF=1, WF=1, stride=1, pad=0)
        if verbose:
            print(x_col.shape, 'x_col shape')

        x_shape = (x_shape[0], x_shape[3], x_shape[1], x_shape[2])
        x_reconstruct = self.col2im(col=x_col, X_shape=x_shape, HF=HF, WF=WF, stride=stride, pad=0)
        x_reconstruct_T = np.transpose(x_reconstruct, (0, 2, 3, 1))
        if verbose:
            print(x_reconstruct.shape, 'x_reconstruct shape')
            print(x_reconstruct_T.shape, 'x_reconstruct_T shape')

        return x_reconstruct_T

    def test(self, x, HF, WF, stride, save=True, verbose=True):
        x_patches = self.get_patches(x, HF=HF, WF=WF, stride=stride, verbose=verbose)
        x_reconstruct = self.get_img(x_patches, x_shape=x.shape, HF=HF, WF=WF, stride=stride, verbose=verbose)

        if save:
            idx = np.random.randint(0, x.shape[0])
            self.save_image(img=x[idx].reshape(28, 28), N=0)
            self.save_image(img=x_reconstruct[idx].reshape(28, 28), N=1)

        return x_reconstruct

impatch = ImPatch()

(x, _), (_, _) = keras.datasets.mnist.load_data()
x = np.expand_dims(x[0:10], axis=-1)

HF, WF, stride = 4, 4, 4
impatch.test(x, HF=HF, WF=WF, stride=stride)
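Because col2im relies on np.add.at, overlapping contributions are summed rather than averaged. With stride < HF, the reconstruction therefore comes out scaled by how many patches cover each pixel. Here is a self-contained NumPy sketch of the same gather/scatter-add idea with a count normalization; `patch_indices` is a hypothetical helper, and a single channel is used for brevity:

```python
import numpy as np

def patch_indices(h, w, k, stride):
    # Row/column index of every pixel in every k x k patch: arrays of shape (n_patches, k, k)
    tops = np.arange(0, h - k + 1, stride)
    lefts = np.arange(0, w - k + 1, stride)
    ty, tx = np.meshgrid(tops, lefts, indexing='ij')
    dy, dx = np.meshgrid(np.arange(k), np.arange(k), indexing='ij')
    rows = ty.reshape(-1, 1, 1) + dy
    cols = tx.reshape(-1, 1, 1) + dx
    return rows, cols

img = np.random.default_rng(4).random((28, 28))
rows, cols = patch_indices(28, 28, k=4, stride=2)      # overlapping: stride < k
patches = img[rows, cols]                              # im2col-style gather, (169, 4, 4)
summed = np.zeros_like(img)
np.add.at(summed, (rows, cols), patches)               # col2im-style scatter-add: overlaps are summed
counts = np.zeros_like(img)
np.add.at(counts, (rows, cols), 1.0)                   # how many patches cover each pixel
assert np.allclose(summed / counts, img)               # dividing by the coverage recovers the image
```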


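Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original source. For any questions, contact yoyou2525@163.com.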

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM