簡體   English   中英

如何在多輸出 model 訓練期間加載數據,而無需在 Keras 中進行迭代?

[英]How to load data during training of a multi-output model without iteration in Keras?

I have a Keras model with 1 input and 2 outputs in TensorFlow 2. When calling model.fit I want to pass dataset as x=train_dataset and call model.fit once. train_datasettf.data.Dataset.from_generator生成:x1, y1, y2。

我可以進行培訓的唯一方法是:

for x1, y1,y2 in train_dataset:
    model.fit(x=x1, y=[y1,y2],...)

如何告訴 TensorFlow 解包變量並在沒有顯式for循環的情況下進行訓練? 使用for循環使許多事情變得不那么實用,以及train_on_batch的使用。

如果我想運行model.fit(train_dataset, ...) function 不明白xy是什么,甚至 model 的定義如下:

model = Model(name ='Joined_Model',inputs=self.x, outputs=[self.network.y1, self.network.y2])

它會拋出一個錯誤,即在獲得 1 時期望 2 個目標,即使數據集有 3 個變量,也可以在循環中迭代。

數據集和小批量生成為:

def dataset_joined(self, n_epochs, buffer_size=32):
    dataset = tf.data.Dataset.from_generator(
        self.mbatch_gen_joined,
        (tf.float32, tf.float32,tf.int32),
        (tf.TensorShape([None, None, self.n_feat]),
            tf.TensorShape([None, None, self.n_feat]),
            tf.TensorShape([None, None])),
        [tf.constant(n_epochs)]
        )
    dataset = dataset.prefetch(buffer_size)
    return dataset

    def mbatch_gen_joined(self, n_epochs):
    for _ in range(n_epochs):
        random.shuffle(self.train_s_list)
        start_idx, end_idx = 0, self.mbatch_size
        for _ in range(self.n_iter):
            s_mbatch_list = self.train_s_list[start_idx:end_idx]
            d_mbatch_list = random.sample(self.train_d_list, end_idx-start_idx)
            s_mbatch, d_mbatch, s_mbatch_len, d_mbatch_len, snr_mbatch, label_mbatch, _ = \
                self.wav_batch(s_mbatch_list, d_mbatch_list)
            x_STMS_mbatch, xi_bar_mbatch, _ = \
                self.training_example(s_mbatch, d_mbatch, s_mbatch_len,
                d_mbatch_len, snr_mbatch)
            #seq_mask_mbatch = tf.cast(tf.sequence_mask(n_frames_mbatch), tf.float32)
            start_idx += self.mbatch_size; end_idx += self.mbatch_size
            if end_idx > self.n_examples: end_idx = self.n_examples

            yield x_STMS_mbatch, xi_bar_mbatch, label_mbatch

Keras 模型期望 Python 生成器或tf.data.Dataset對象將輸入數據作為元組提供,格式為(input_data, target_data) (或(input_data, target_data, sample_weights) )。 如果 model 具有多個輸入/輸出層,則每個input_datatarget_data都可以而且應該是一個列表/元組。 因此,在您的代碼中,生成的數據也應該與這種預期格式兼容:

yield x_STMS_mbatch, (xi_bar_mbatch, label_mbatch)  # <- the second element is a tuple itself

此外,在傳遞給from_generator方法的 arguments 中也應考慮這一點:

dataset = tf.data.Dataset.from_generator(
    self.mbatch_gen_joined,
    output_types=(
        tf.float32,
        (tf.float32, tf.int32)
    ),
    output_shapes=(
        tf.TensorShape([None, None, self.n_feat]),
        (
            tf.TensorShape([None, None, self.n_feat]),
            tf.TensorShape([None, None])
        )
    ),
    args=(tf.constant(n_epochs),)
)

使用yield(x1, [y1,y2])所以 model.fit 將了解您的生成器 output。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM