How to reshape data in a TensorFlow dataset?

I am writing a data pipeline to feed batches of time-series sequences and their corresponding labels into an LSTM model, which requires a 3D input shape. I currently have the following:

import tensorflow as tf

# data, input_length, label_length, label_shift, shuffle_buffer,
# shuffle_seed and batch_size are defined elsewhere.

def split(window):
    # Split each window into the input sequence and its label.
    return window[:-label_length], window[-label_length]

# Slide a window over the series, then turn each window into a single tensor.
dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

Checking x.shape inside for x, y in dataset.take(1): gives (32, 20), where 32 is the batch size and 20 is the sequence length. I need a shape of (32, 20, 1), where the additional dimension denotes the feature.

My question is how to do this reshape, ideally in the split function that is passed to dataset.map, so that it happens before the data is cached.

That's easy. Do this in your split function:

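def split(window):
    # tf.newaxis appends a feature axis: inputs become (input_length, 1),
    # and the single label value becomes (1, 1).
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]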
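
To check the effect of this change, here is a minimal sketch of the pipeline run on a synthetic sine wave. The series and the values of input_length, label_length, label_shift and batch_size are assumptions chosen to match the shapes mentioned in the question (batch size 32, sequence length 20); they are not taken from the original post. The cache, shuffle and prefetch steps are left out because they do not affect the shapes.

```python
import numpy as np
import tensorflow as tf

# Assumed values: the question only states batch size 32 and sequence length 20.
input_length = 20
label_length = 1
label_shift = 1
batch_size = 32

# Hypothetical stand-in for data.sin from the question.
series = np.sin(np.linspace(0.0, 100.0, 2000)).astype(np.float32)

def split(window):
    # window is a 1D tensor of length input_length + label_length.
    inputs = window[:-label_length, tf.newaxis]            # (input_length, 1)
    label = window[-label_length, tf.newaxis, tf.newaxis]  # (1, 1)
    return inputs, label

dataset = tf.data.Dataset.from_tensor_slices(series)
dataset = dataset.window(input_length + label_length, shift=label_shift,
                         stride=1, drop_remainder=True)
# Each window is itself a small dataset; batch it into one tensor.
dataset = dataset.flat_map(lambda w: w.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)

# Pull one batch and inspect its shapes.
x, y = next(iter(dataset))
print(x.shape, y.shape)
```

Under these assumptions, x comes out as (32, 20, 1), which is the 3D input the LSTM expects. The label batch comes out as (32, 1, 1) with this indexing; whether that matches the model's output shape depends on the final layer and loss, which the question does not specify.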
