
How to implement a sliding window over an image (2D array) sequence in TensorFlow?

Context

We have our data stored in .tfrecord files: X is our training data (40x40 greyscale images) and Y are the labels. The images are ordered in a sequence (order is important). We would like to feed these images through TensorFlow's Estimator API to train a neural network model (for example, an LSTM) with various time-window sizes and shifts using GoogleML.

Question

How can we reshape the input stream of features into sequences of a certain length (e.g. put 1000 images into one sequence) and then perform windowing on these sequences (e.g. get windows of 50 images with a window shift of 25)?
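The sketch below pins down the intended behaviour by listing the index ranges of the windows we are after. It is plain Python with no TensorFlow involved. It is illustrative only: `windows_within_sets` is a hypothetical helper, indices are 0-based with exclusive ends, and any trailing partial set is ignored.

```python
def windows_within_sets(n_items, set_len, size, shift):
    """Hypothetical helper: return (start, end) index pairs (end exclusive)
    for windows that never cross a set boundary."""
    windows = []
    # Walk over the stream one complete set of set_len items at a time.
    for set_start in range(0, n_items - set_len + 1, set_len):
        # Slide the window only inside the current set.
        for offset in range(0, set_len - size + 1, shift):
            start = set_start + offset
            windows.append((start, start + size))
    return windows


w = windows_within_sets(n_items=2000, set_len=1000, size=50, shift=25)
print(len(w))        # 78 windows: 39 per set of 1000
print(w[38], w[39])  # (950, 1000) (1000, 1050)
```

The last window of set 0 ends exactly at index 1000 and the next one starts there. No window mixes images from two sets.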

Current state

We have managed to achieve this (condensed example below) without first reshaping into sets of 1000. However, the result then contains windows that span from element 975 of one set into the first 25 elements of the next, which we do not want. We need overlapping windows that span from the beginning to the end of every set of 1000 images but never cross the set boundaries.
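The boundary problem can be reproduced with plain index arithmetic. This pure-Python sketch is illustrative only and both helper names are hypothetical. It mimics `window(size, shift, drop_remainder=True)` applied to the flat stream, then reports which windows straddle a set boundary.

```python
def naive_windows(n_items, size, shift):
    """Mimic Dataset.window(size, shift, drop_remainder=True) on indices."""
    return [(s, s + size) for s in range(0, n_items - size + 1, shift)]


def crossing(windows, set_len):
    """Windows whose first and last element fall into different sets."""
    return [(s, e) for s, e in windows if s // set_len != (e - 1) // set_len]


w = naive_windows(n_items=3000, size=50, shift=25)
print(crossing(w, set_len=1000))  # [(975, 1025), (1975, 2025)]
```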

import tensorflow as tf

# .tfrecord file consisting of data 'X' and labels 'Y'
dataset = tf.data.TFRecordDataset('.tfrecord file')

# define parse function for dataset.map function
def _parse_function(proto):
    # define constants for parsing
    image_size = 40
    num_channels = 1
    num_classes = 3


    # define your tfrecord feature keys and 
    # reshape 1D arrays into 2D arrays (images)
    keys_to_features = {'X': tf.FixedLenFeature([image_size, image_size, num_channels], tf.float32),  # image height, image width, num_channels
                    'Y': tf.FixedLenFeature([], tf.int64)}

    # Load one example
    parsed_features = tf.parse_single_example(proto, keys_to_features)

    # extract image and labels
    image = parsed_features['X']
    labels = tf.cast( parsed_features['Y'], tf.int32 )
    labels = tf.one_hot( labels, depth=num_classes )  # one hot encoding

    return image, labels

# reshape the data into parse format
dataset = dataset.map(_parse_function)

# define dataset parameters
window_size = 50
batch_size = 500
window_shift = window_size // 2  # 25

# implement sliding window; each element is an (image, label) pair, so window()
# yields a pair of sub-datasets that are batched separately and zipped back together
dataset = dataset.window(size=window_size, shift=window_shift, drop_remainder=True).flat_map(lambda x, y: tf.data.Dataset.zip((x.batch(window_size), y.batch(window_size))))

# batch the data
dataset = dataset.batch(batch_size)

# create an iterator
# iterator = dataset.make_one_shot_iterator().get_next()

The iterator above returns X as a tensor of shape (batch_size, window_size, image_height, image_width, num_channels), in our case (500, 50, 40, 40, 1). It returns Y as a (500, 50, 3) tensor of one-hot labels, with one label per image in each window.
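Simple arithmetic gives a rough sanity check on these shapes: the number of windows and batches per epoch. The pure-Python sketch below rests on three assumptions:

- `window` is called with `drop_remainder=True`.
- `batch` keeps its default `drop_remainder=False`.
- The input holds 100,000 images. This figure is purely hypothetical, as is the helper name `epoch_counts`.

```python
import math


def epoch_counts(n_images, window_size, shift, batch_size):
    """Hypothetical helper: windows from window(..., drop_remainder=True) and
    batches from batch(batch_size) with the default drop_remainder=False."""
    if n_images < window_size:
        return 0, 0
    n_windows = (n_images - window_size) // shift + 1
    n_batches = math.ceil(n_windows / batch_size)
    return n_windows, n_batches


print(epoch_counts(n_images=100_000, window_size=50, shift=25, batch_size=500))
# (3999, 8)
```

With these numbers the last batch holds 3999 - 7 * 500 = 499 windows, so its X tensor would be (499, 50, 40, 40, 1). This is also why the static batch dimension appears as None unless `batch(..., drop_remainder=True)` is used. The count is taken before any boundary filtering.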

Answer

I managed to do this by filtering out the windows that cross the boundaries. Once you have the parsed features, apply windowing over the whole stream. Then work out which windows overflow into the next set and filter them out:

ds = tf.data.TFRecordDataset( filename )
ds = ds.map( _parse_function )

# apply windowing
ds = ds.window( size=50, shift=25, drop_remainder=True ).flat_map( lambda x, y: tf.data.Dataset.zip( (x.batch(50), y.batch(50)) ) )
# enumerate the windows starting from 1 and drop every 40th one
ds = ds.apply( tf.data.experimental.enumerate_dataset(start=1) ).filter( lambda i, x: tf.not_equal( i % 40, 0) )
# get rid of enumerations
ds = ds.map( lambda i, x: x )

# batching, shuffling etc...
...

Clarification: every 40th window is filtered out for the following reason. With sets of 1000 and a window shift of 25, set_len / win_shift = 40 windows start in each set, and the last one (i.e. the 40th) overflows into the next set. Note also that the enumeration starts at 1 so that the 0th window is not removed, since 0 % x == 0.
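The arithmetic behind the `i % 40` filter can be checked without TensorFlow. This pure-Python sketch enumerates the naive window start positions from 1, as `enumerate_dataset(start=1)` does. It then confirms that the dropped indices are exactly the windows that cross a set boundary. The stream length of 5000 is an illustrative assumption.

```python
SET_LEN, SIZE, SHIFT = 1000, 50, 25
N_ITEMS = 5000  # assumed stream length: five sets of 1000

# Start positions of the windows produced by window(..., drop_remainder=True).
starts = range(0, N_ITEMS - SIZE + 1, SHIFT)
period = SET_LEN // SHIFT  # 40 window starts per set

# Enumerate from 1, like enumerate_dataset(start=1), and split on i % 40.
kept = [s for i, s in enumerate(starts, start=1) if i % period != 0]
dropped = [s for i, s in enumerate(starts, start=1) if i % period == 0]

# A window crosses a boundary if its first and last items lie in different sets.
crosses = [s for s in starts if s // SET_LEN != (s + SIZE - 1) // SET_LEN]

print(dropped)             # [975, 1975, 2975, 3975]
print(dropped == crosses)  # True
```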

Note that this is more of a hack than a true solution. It works well with 50% overlap, but with other overlaps it gets more complicated to work out which indices to throw out. With more than 50% overlap, more than one window overflows into the next set, so multiple filters would be required.
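One way to avoid hand-computing indices for other overlaps is to derive the keep condition from each window's start position. Assume enumeration starts at 0, so window `i` starts at `i * shift`. That window stays inside one set exactly when `(i * shift) % set_len + size <= set_len`. The pure-Python sketch below checks this predicate; `keep_window` and `dropped_per_set` are hypothetical names. It removes one window per set at 50% overlap and four at 80% overlap.

```python
def keep_window(i, set_len, size, shift):
    """True if the i-th window (0-based, starting at i * shift in the flat
    stream) lies entirely inside one set of length set_len."""
    offset = (i * shift) % set_len  # position of the window start within its set
    return offset + size <= set_len


def dropped_per_set(set_len, size, shift):
    """Count windows rejected by keep_window among the starts of one set."""
    n_starts = set_len // shift  # assumes set_len is a multiple of shift
    return sum(not keep_window(i, set_len, size, shift) for i in range(n_starts))


print(dropped_per_set(1000, 50, 25))  # 1 (50% overlap)
print(dropped_per_set(1000, 50, 10))  # 4 (80% overlap)
```

In tf.data, the same condition could presumably be applied to a dataset enumerated from 0. For example, use `ds.enumerate().filter(lambda i, x: (i * shift) % set_len + size <= set_len)` in TF 2.x, or `tf.data.experimental.enumerate_dataset()` in older versions. This has not been benchmarked. If `set_len` is not a multiple of `shift`, window starts drift relative to the set boundaries. In that case, grouping the stream into sets first is likely the cleaner option: use `batch(set_len, drop_remainder=True)`, then window each set inside `flat_map`.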
