[英]How to reshape data in Tensorflow dataset?
我正在編寫一個數據管道,將批次的時間序列序列和相應的標簽輸入到 LSTM model 中,這需要 3D 輸入形狀。 我目前有以下內容:
def split(window):
return window[:-label_length], window[-label_length]
dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
for x, y in dataset.take(1): x.shape
為 (32, 20),其中 32 是批量大小,20 是序列長度,但我需要 (32, 20) 的形狀, 1),其中附加維度表示特征。
我的問題是如何重塑,理想情況下,在緩存數據之前傳遞到數據集的split
dataset.map
中。
這很容易。 在您的拆分 function 中執行此操作
def split(window):
return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.