
Reconstructing an image after using extract_image_patches

I have an autoencoder that takes an image as an input and produces a new image as an output.

The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network.

Once I have the output, also a batch of patches size 1024x32x32x3, I want to be able to reconstruct a 1024x1024x3 image. I thought I had this sussed by simply reshaping, but here's what happened.

First, the image as read by TensorFlow: [input image]

I patched the image with the following code:

patch_size = [1, 32, 32, 1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [1024, 32, 32, 3])

Here are a couple of patches from this image:

[patch input #168] [patch input #169]

But it's when I reshape this patch data back into an image that things go pear-shaped.

reconstructed = tf.reshape(patches, [1, 1024, 1024, 3])
converted = tf.image.convert_image_dtype(reconstructed, tf.uint8)
encoded = tf.image.encode_png(converted)

[reconstructed output]

In this example, no processing has been done between patching and reconstructing. I have made a version of the code you can use to test this behaviour. To use it, run the following:

echo "/path/to/test-image.png" > inputs.txt
mkdir images
python3 image_test.py inputs.txt images

The code will make one input image, one output image, and one patch image for each of the 1024 patches in each input image, so comment out the lines that create the input and output images if you're only interested in saving the patches.

Somebody please explain what happened :(

Since I also struggled with this, I'm posting a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is its gradient, as suggested here. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:

import tensorflow as tf
from keras import backend as K
import numpy as np

def extract_patches(x):
    return tf.extract_image_patches(
        x,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad, to "average" together the overlapping patches
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad

# Generate 10 fake images, last dimension can be different than 3
images = np.random.random((10, 28, 28, 3)).astype(np.float32)
# Extract patches
patches = extract_patches(images)
# Reconstruct image
# Notice that original images are only passed to infer the right shape
images_reconstructed = extract_patches_inverse(images, patches) 

# Compare with original (evaluating tf.Tensor into a numpy array)
# Here using Keras session
images_r = images_reconstructed.eval(session=K.get_session())

print (np.sum(np.square(images - images_r))) 
# 2.3820458e-11
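
In TF 2.x, where tf.gradients is not available in eager mode, the same trick can be sketched with tf.GradientTape (a minimal sketch under the same assumptions, plus the requirement that every pixel is covered by at least one patch so the per-pixel count is never zero):

import tensorflow as tf

def extract_patches(x):
    return tf.image.extract_patches(
        x,
        sizes=(1, 3, 3, 1),
        strides=(1, 1, 1, 1),
        rates=(1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(shape, patches):
    _x = tf.zeros(shape)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(_x)
        _y = extract_patches(_x)
    # Backpropagating ones counts how many patches cover each pixel...
    counts = tape.gradient(_y, _x)
    # ...and backpropagating the patches scatter-adds their contents back;
    # dividing by the counts averages the overlapping patches
    summed = tape.gradient(_y, _x, output_gradients=patches)
    return summed / counts

images = tf.random.normal((10, 28, 28, 3))
patches = extract_patches(images)
images_r = extract_patches_inverse(tf.shape(images), patches)
print(tf.reduce_sum(tf.square(images - images_r)).numpy())  # ~0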

Use Update 2 below. One small example for your task (TF 1.0):

Consider an image of size (4,4,1), converted to patches of size (4,2,2,1) and reconstructed back into the image.

import tensorflow as tf
image = tf.constant([[[1],   [2],  [3],  [4]],
                     [[5],   [6],  [7],  [8]],
                     [[9],  [10], [11], [12]],
                     [[13], [14], [15], [16]]])

patch_size = [1,2,2,1]
patches = tf.extract_image_patches([image],
    patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [4, 2, 2, 1])
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed,2)
rec_new = tf.reshape(rec_new,[4,4,1])

sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I)
print(I.shape)
print(P.shape)
print(R_n)
print(R_n.shape)

Output:

[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
  [[ 5][ 6][ 7][ 8]]
  [[ 9][10][11][12]]
  [[13][14][15][16]]]
(4, 4, 1)
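
(Without the space_to_depth step, tf.reshape(patches, [1, 4, 4, 1]) alone would lay the patch contents out in patch order,

[[ 1  2  5  6]
 [ 3  4  7  8]
 [ 9 10 13 14]
 [11 12 15 16]]

which is exactly the kind of scrambling seen in the question.)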

Update - for 3 channels (debugging..)

This works only when p = sqrt(h), i.e. when the number of patches (h/p)² equals h:

import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 32

image = tf.random_normal([h,h,c])
patch_size = [1,p,p,1]
patches = tf.extract_image_patches([image],
   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [h, p, p, c])
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed,p)
rec_new = tf.reshape(rec_new,[h,h,c])

sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)

Output:

(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0

Update 2

Reconstructing from the output of extract_image_patches seems difficult. Using other functions to extract patches and reversing the process to reconstruct seems easier.

import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 128


image = tf.random_normal([1,h,h,c])

# Image to Patches Conversion
pad = [[0,0],[0,0]]
patches = tf.space_to_batch_nd(image,[p,p],pad)
patches = tf.split(patches,p*p,0)
patches = tf.stack(patches,3)
patches = tf.reshape(patches,[(h//p)**2,p,p,c])  # integer division for Python 3

# Do processing on patches
# Using patches here to reconstruct
patches_proc = tf.reshape(patches,[1,h//p,h//p,p*p,c])
patches_proc = tf.split(patches_proc,p*p,3)
patches_proc = tf.stack(patches_proc,axis=0)
patches_proc = tf.reshape(patches_proc,[p*p,h//p,h//p,c])

reconstructed = tf.batch_to_space_nd(patches_proc,[p, p],pad)

sess = tf.Session()
I,P,R_n = sess.run([image,patches,reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)

Output:

(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0

You can see other cool tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops

tf.extract_image_patches is quite difficult to use, as it does a lot of stuff in the background.

If you just need non-overlapping patches, it's much easier to write it yourself. You can reconstruct the full image by inverting all the operations in image_to_patches.

Code sample (plots original image and patches):

import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt


def image_to_patches(image, patch_height, patch_width):
    # resize image so that its dimensions are divisible by patch_height and patch_width
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)
    height = tf.cast(tf.ceil(image_height / patch_height) * patch_height, dtype=tf.int32)
    width = tf.cast(tf.ceil(image_width / patch_width) * patch_width, dtype=tf.int32)

    num_rows = height // patch_height
    num_cols = width // patch_width
    # zero-pad (or crop) to the target size
    image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))

    # get slices along the 0-th axis
    image = tf.reshape(image, [num_rows, patch_height, width, -1])
    # h/patch_h, w, patch_h, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # get slices along the 1-st axis
    # h/patch_h, w/patch_w, patch_w,patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # num_patches, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1])
    # num_patches, patch_h, patch_w, c
    return tf.transpose(image, [0, 2, 1, 3])


image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
print('Original image shape:', image.shape)
tile_size = 200
image = tf.constant(image)
tiles = image_to_patches(image, tile_size, tile_size)

sess = tf.Session()
I, tiles = sess.run([image, tiles])
print(I.shape)
print(tiles.shape)


plt.figure(figsize=(1 * (4 + 1), 5))
plt.subplot(5, 1, 1)
plt.imshow(I)
plt.title('original')
plt.axis('off')
for i, tile in enumerate(tiles):
    plt.subplot(5, 5, 5 + 1 + i)
    plt.imshow(tile)
    plt.title(str(i))
    plt.axis('off')
plt.show()
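
The inverse can be written by undoing each reshape/transpose above in reverse order. A minimal sketch (patches_to_image is a name introduced here, assuming the image dimensions were already divisible by the patch sizes so no cropping or padding took place):

def patches_to_image(patches, num_rows, num_cols, patch_height, patch_width):
    # num_patches, patch_w, patch_h, c
    image = tf.transpose(patches, [0, 2, 1, 3])
    # h/patch_h, w/patch_w, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # h/patch_h, w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols * patch_width, patch_height, -1])
    # h/patch_h, patch_h, w, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # h, w, c
    return tf.reshape(image, [num_rows * patch_height, num_cols * patch_width, -1])

Here num_rows and num_cols are the patch-grid dimensions computed inside image_to_patches.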
Splitting the batch of patches into rows and concatenating them back also works for non-overlapping patches (TF 2.x):

_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size**2)  # assume square patch

patches = tf.image.extract_patches(image, sizes=[1,patch_size,patch_size,1], strides=[1,patch_size,patch_size,1], rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])

rows = tf.split(patches, n_col//patch_size, axis=0)
rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]

reconstructed = tf.concat(rows, axis=0)

I don't know if this is an efficient implementation but it works!
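
A quick self-contained sanity check of the snippet (a sketch, assuming TF 2.x eager mode and a square image whose side is divisible by patch_size):

import tensorflow as tf

patch_size = 32
image = tf.random.normal([1, 1024, 1024, 3])

_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size**2)
patches = tf.image.extract_patches(image, sizes=[1,patch_size,patch_size,1], strides=[1,patch_size,patch_size,1], rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
rows = tf.split(patches, n_col//patch_size, axis=0)
rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]
reconstructed = tf.concat(rows, axis=0)

print(tf.reduce_max(tf.abs(reconstructed - image[0])).numpy())  # 0.0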

To specifically address the initial question, which is 'Reconstructing an image after using extract_image_patches', I propose using tf.scatter_nd() and building a stratified image. This will work even when the extracted patches overlap or the image is under-sampled. Here is my proposed solution.

import cv2
import numpy as np
import tensorflow as tf

# Function to extract patches using 'extract_image_patches'
def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):

    with tf.variable_scope('im2_patches'):
        patches = tf.image.extract_image_patches(
            images=raw_input,
            ksizes=[1, _patch_size[0], _patch_size[1], 1],
            strides=[1, _stride, _stride, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )

        h = tf.shape(patches)[1]
        w = tf.shape(patches)[2]
        patches = tf.reshape(patches, (patches.shape[0], -1, _patch_size[0], _patch_size[1], 3))
    return patches, (h, w)


# Function to reconstruct image
def patches_to_img(update, _block_shape, _stride=100):
    with tf.variable_scope('patches2im'):
        _h = _block_shape[0]
        _w = _block_shape[1]

        bs = tf.shape(update)[0]  # batch size
        num_p = tf.shape(update)[1]  # number of patches (renamed so it doesn't shadow numpy's np)
        ps_h = tf.shape(update)[2]  # patch height
        ps_w = tf.shape(update)[3]  # patch width
        col_ch = tf.shape(update)[4]  # Colour channel count

        wout = (_w - 1) * _stride + ps_w  # Recalculate output shape of "extract_image_patches" including padded pixels
        hout = (_h - 1) * _stride + ps_h  # Recalculate output shape of "extract_image_patches" including padded pixels

        x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
        x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
        y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
        xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                     tf.range(0, (hout - ps_h) + 1, _stride))

        bb = tf.zeros((1, num_p, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))  #  batch indices
        yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))  # y indices
        xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))  # x indices
        cc = tf.zeros((bs, num_p, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))  # color indices
        dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(num_p), (1, -1, 1, 1, 1, 1))  # shift indices

        idx = tf.concat([bb, yy, xx, cc, dd], -1)

        stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, num_p))
        stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))

        stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, num_p))
        stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))

        with tf.variable_scope("consolidate"):
            sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
            stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
            reconstructed_img = tf.divide(sum_stratified_img, stratified_img_count)

        return reconstructed_img, stratified_img



if __name__ == "__main__":

    # load initial image
    image_org = cv2.imread('orig_img.jpg')
    # Add batch dimension
    image = np.expand_dims(image_org, axis=0)

    # set parameters
    patch_size = (228, 228)
    stride = 200

    input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

    # Extract patches using "extract_image_patches()"
    extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
    # block_shape is the number of patches extracted in the x and in the y dimension
    # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)

    reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # Reconstruct Image


    with tf.Session() as sess:
        ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img], feed_dict={input_img: image})
        # print(bs)
    si = si.astype(np.int32)

    # Show reconstructed image
    cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
    cv2.waitKey(0)

    # Show stratified images
    for i in range(si.shape[1]):
        im_1 = si[0, i, :, :, :]
        cv2.imshow('sd', im_1.astype(np.float32)/255)
        cv2.waitKey(0)  # pause on each stratified image

The above solution should work for batched images with an arbitrary number of colour channels.

This code works for your specific case, as well as for cases where the image is square, the kernel is square, and the image size is divisible by the kernel size.

I did not test it for other cases.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


size = 1024
k_size = 32
axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))

# Define a placeholder for image (or load it directly if you prefer) 
img = tf.placeholder(tf.int32, shape=(1, size, size, 3))

# Extract patches
patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1], 
                                         strides=[1, k_size, k_size, 1], 
                                         rates=[1, 1, 1, 1], padding='VALID')

# Reconstruct the image back from the patches
# First split the flat patch dimension into (pixel_row, pixel_col, channel)
reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3)) 
# Transpose so that each patch's pixel rows line up with the image's rows
# (axes become: batch, patch_row, pixel_row, patch_col, pixel_col, channel)
reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
# Reshape back
reconstruct = tf.reshape(reconstruct, (size, size, 3))

im_arr = np.random.randint(0, 255, size=(size, size, 3))  # placeholder: load an image with shape (size, size, 3) here

# Run the operations
with tf.Session() as sess:
    ps, r = sess.run([patches, reconstruct], feed_dict={img:[im_arr]})

# Plot the reconstructed image to verify
plt.imshow(r)

If you are not using overlapping blocks, TF 2.0 users can use space_to_depth and depth_to_space.
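
A minimal sketch of that idea (assuming TF 2.x and a patch size p that divides the image size evenly):

import tensorflow as tf

p = 32
image = tf.random.normal([1, 1024, 1024, 3])

# [1, H, W, C] -> [1, H/p, W/p, p*p*C]: every pxp block is packed into the depth axis
blocks = tf.nn.space_to_depth(image, p)
# optional: view as a batch of patches, [(H/p)*(W/p), p, p, C]
patches = tf.reshape(blocks, [-1, p, p, 3])

# ... process the patches ...

# pack the patches back into the depth axis and unpack the blocks to spatial positions
reconstructed = tf.nn.depth_to_space(tf.reshape(patches, tf.shape(blocks)), p)
print(tf.reduce_max(tf.abs(reconstructed - image)).numpy())  # 0.0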

I may be a bit late, but since I got it working with TF 2.3, it might prove useful for others. The following code works for non-overlapping patches, single- or multi-channel:

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

class PatchesToImage(layers.Layer):
    def __init__(self, imgh, imgw, imgc, patsz, is_squeeze=True, **kwargs):
        super(PatchesToImage, self).__init__(**kwargs)
        self.H = (imgh // patsz) * patsz
        self.W = (imgw // patsz) * patsz
        self.C = imgc
        self.P = patsz
        self.is_squeeze = is_squeeze
        
    def call(self, inputs):
        bs = tf.shape(inputs)[0]
        rows, cols = self.H // self.P, self.W // self.P
        patches = tf.reshape(inputs, [bs, rows, cols, -1, self.C])
        pats_by_clist = tf.unstack(patches, axis=-1)
        def tile_patches(ii):
            pats = pats_by_clist[ii]
            img = tf.nn.depth_to_space(pats, self.P)
            return img 
        img = tf.map_fn(fn=tile_patches, elems=tf.range(self.C), fn_output_signature=inputs.dtype)
        img = tf.squeeze(img, axis=-1)
        img = tf.transpose(img, perm=[1,2,3,0])
        C = tf.shape(img)[-1]
        img = tf.cond(tf.logical_and(tf.constant(self.is_squeeze), C==1), 
                      lambda: tf.squeeze(img, axis=-1), lambda: img)
        return img
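
A possible usage sketch (assuming eager execution, and that the layer's input is the flattened-patch output of tf.image.extract_patches with non-overlapping strides):

img = tf.random.normal([2, 64, 64, 3])
P = 8
patches = tf.image.extract_patches(img, sizes=[1, P, P, 1], strides=[1, P, P, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')  # [2, 8, 8, P*P*3]
rebuilt = PatchesToImage(64, 64, 3, P)(patches)
print(tf.reduce_max(tf.abs(rebuilt - img)).numpy())  # ~0.0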

Implemented using the im2col and col2im methods

import numpy as np
import keras
import tensorflow as tf
import matplotlib.pyplot as plt


class ImPatch():
    def __init__(self):
        pass

    def save_image(self, img, N=None):
        plt.imshow(img)
        plt.savefig(str(N))
        plt.clf()

    def get_indices(self, X_shape, HF, WF, stride, pad):
        # get input size
        m, n_C, n_H, n_W = X_shape

        # get output size
        out_h = int((n_H + 2 * pad - HF) / stride) + 1
        out_w = int((n_W + 2 * pad - WF) / stride) + 1

        # ----Compute matrix of index i----

        # Level 1 vector.
        level1 = np.repeat(np.arange(HF), WF)
        # Duplicate for the other channels.
        level1 = np.tile(level1, n_C)
        # Create a vector with an increase by 1 at each level.
        everyLevels = stride * np.repeat(np.arange(out_h), out_w)
        # Create matrix of index i at every level for each channel.
        i = level1.reshape(-1, 1) + everyLevels.reshape(1, -1)

        # ----Compute matrix of index j----

        # Slide 1 vector.
        slide1 = np.tile(np.arange(WF), HF)
        # Duplicate for the other channels.
        slide1 = np.tile(slide1, n_C)
        # Create a vector with an increase by 1 at each slide.
        everySlides = stride * np.tile(np.arange(out_w), out_h)
        # Create matrix of index j at every slide for each channel.
        j = slide1.reshape(-1, 1) + everySlides.reshape(1, -1)

        # ----Compute matrix of index d----

        # This is to mark delimitation for each channel
        # during multi-dimensional array indexing.
        d = np.repeat(np.arange(n_C), HF * WF).reshape(-1, 1)

        return i, j, d

    def im2col(self, X, HF, WF, stride, pad):
        # Padding
        X_padded = np.pad(X, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
        i, j, d = self.get_indices(X.shape, HF, WF, stride, pad)
        # Multi-dimensional array indexing.
        cols = X_padded[:, d, i, j]
        cols = np.concatenate(cols, axis=-1)
        return cols

    def col2im(self, col, X_shape, HF, WF, stride, pad):
        # Get input size
        N, D, H, W = X_shape
        # Add padding if needed.
        H_padded, W_padded = H + 2 * pad, W + 2 * pad
        X_padded = np.zeros((N, D, H_padded, W_padded))

        # Index matrices, necessary to transform our input image into a matrix.
        i, j, d = self.get_indices(X_shape, HF, WF, stride, pad)
        # Retrieve batch dimension by splitting dX_col N times: (X, Y) => (N, X, Y)
        dX_col_reshaped = np.array(np.hsplit(col, N))
        # Reshape our matrix back to image.
        # slice(None) is used to produce the [::] effect, which means "every element".
        np.add.at(X_padded, (slice(None), d, i, j), dX_col_reshaped)
        # Remove padding from new image if needed
        # (the padding was applied to the last two, spatial, axes).
        if pad == 0:
            return X_padded
        elif type(pad) is int:
            return X_padded[:, :, pad:-pad, pad:-pad]

    def get_patches(self, x, HF, WF, stride, verbose=False):
        x_patches = tf.image.extract_patches(x, sizes=[1, HF, WF, 1], strides=[1, stride, stride, 1], rates=[1, 1, 1, 1], padding='VALID')
        if verbose:
            print(x_patches.shape, 'x_patches shape')

        return x_patches

    def get_img(self, x_patches, x_shape, HF, WF, stride, verbose=False):
        x_patches_T = np.transpose(x_patches, (0, 3, 1, 2))
        x_col = self.im2col(X=x_patches_T, HF=1, WF=1, stride=1, pad=0)
        if verbose:
            print(x_col.shape, 'x_col shape')

        x_shape = (x_shape[0], x_shape[3], x_shape[1], x_shape[2])
        x_reconstruct = self.col2im(col=x_col, X_shape=x_shape, HF=HF, WF=WF, stride=stride, pad=0)
        x_reconstruct_T = np.transpose(x_reconstruct, (0, 2, 3, 1))
        if verbose:
            print(x_reconstruct.shape, 'x_reconstruct shape')
            print(x_reconstruct_T.shape, 'x_reconstruct_T shape')

        return x_reconstruct_T

    def test(self, x, HF, WF, stride, save=True, verbose=True):
        x_patches = self.get_patches(x, HF=HF, WF=WF, stride=stride, verbose=verbose)
        x_reconstruct = self.get_img(x_patches, x_shape=x.shape, HF=HF, WF=WF, stride=stride, verbose=verbose)

        if save:
            idx = np.random.randint(0, x.shape[0])
            self.save_image(img=x[idx].reshape(28, 28), N=0)
            self.save_image(img=x_reconstruct[idx].reshape(28, 28), N=1)

        return x_reconstruct


impatch = ImPatch()

(x, _), (_, _) = keras.datasets.mnist.load_data()
x = np.expand_dims(x[0:10], axis=-1)

HF, WF, stride = 4, 4, 4
impatch.test(x, HF=HF, WF=WF, stride=stride)
