簡體   English   中英

如何重塑 Tensorflow 數據集中的數據?

[英]How to reshape data in Tensorflow dataset?

我正在編寫一個數據管道,將批次的時間序列序列和相應的標簽輸入到 LSTM model 中,這需要 3D 輸入形狀。 我目前有以下內容:

def split(window):
    return window[:-label_length], window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

for x, y in dataset.take(1): x.shape為 (32, 20),其中 32 是批量大小,20 是序列長度,但我需要 (32, 20) 的形狀, 1),其中附加維度表示特征。

我的問題是如何重塑,理想情況下,在緩存數據之前傳遞到數據集的split dataset.map中。

這很容易。 在您的拆分 function 中執行此操作

def split(window):
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM