簡體   English   中英

Tensorflow 2 開始訓練需要很長時間

[英]Starting training takes a very long time in Tensorflow 2

Tensorflow 2 需要大約 15 分鍾來制作其 static 圖表(或在第一次通過之前所做的任何事情)。 這之后的訓練時間是正常的,但顯然很難嘗試等待任何反饋的 15 分鍾。

生成器編碼器和鑒別器是帶有 GRU 單元的 RNN(未展開),位於 Keras model 中。

生成器解碼器的定義和調用如下:

class GeneratorDecoder(tf.keras.layers.Layer):
def __init__(self, feature_dim):
    super(GeneratorDecoder, self).__init__()
    self.cell = tf.keras.layers.GRUCell(
        GRUI_DIM, activation='tanh', recurrent_activation='sigmoid',
        dropout=DROPOUT, recurrent_dropout=DROPOUT)
    self.batch_normalization = tf.keras.layers.BatchNormalization()
    self.dense = tf.keras.layers.Dense(
        feature_dim, activation='tanh')

@tf.function
def __call__(self, z, timesteps, training):
    # z has shape (batch_size, features)
    outputs = []
    output, state = z, z
    for i in range(timesteps):
        output, state = self.cell(inputs=output, states=state,
                                  training=training)
        dense_output = self.dense(
            self.batch_normalization(output))
        outputs.append(dense_output)
    return outputs

這是我的訓練循環(mask_gt 和 missing_data 變量使用 tf.cast 進行轉換,應該已經是張量):

for it in tqdm(range(NO_ITERATIONS)):
   print(it)
   train_step()


@tf.function
def train_step():
    with tf.GradientTape(persistent=True) as tape:
        generator_output = generator(missing_data, training=True)
        imputed_data = get_imputed_data(missing_data, generator_output)
        mask_pred = discriminator(imputed_data)
        D_loss = discriminator.loss(mask_pred, mask_gt)
        G_loss = generator.loss(missing_data, mask_gt,
                                generator_output, mask_pred)
    gen_enc_grad = tape.gradient(
        G_loss, generator.encoder.trainable_variables)
    gen_dec_grad = tape.gradient(
        G_loss, generator.decoder.trainable_variables)
    disc_grad = tape.gradient(
        D_loss, discriminator.model.trainable_variables)
    del tape

    generator.optimizer.apply_gradients(
        zip(gen_enc_grad, generator.encoder.trainable_variables))
    generator.optimizer.apply_gradients(
        zip(gen_dec_grad, generator.decoder.trainable_variables))
    discriminator.optimizer.apply_gradients(
        zip(disc_grad, discriminator.model.trainable_variables))

注意“0”是在幾秒內打印出來的,所以慢的部分肯定不會更早。 這是調用的 get_imputed_data function:

def get_imputed_data(incomplete_series, generator_output):
    return tf.where(tf.math.is_nan(incomplete_series), generator_output, incomplete_series)

感謝您的任何回答。 希望我提供了足夠的代碼來說明問題所在:這是我閱讀至少五年后第一次在這里發帖:)

我使用 Python 3.6 和 Tensorflow 2.1。

該問題已通過刪除 tf.function 裝飾器來解決,該裝飾器用於生成器和鑒別器的調用函數。 我在兩個 tf.function 修飾函數中使用了一個全局 python 標量(迭代號)。 這導致每次都創建一個新圖表(請參閱tf.function 文檔中的注意事項)。

解決方案是刪除使用的 python 變量或將它們轉換為 tensorflow 變量。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM