
What is the purpose of the 'reuse' flag in GAN implemented in tensorflow?

I was going through a GAN tutorial and noticed the use of a 'reuse' flag, but I don't quite understand what it is doing. If you look at the code below, you will see that reuse is used in each of the variable scope initializations.

(I tried looking at the documentation, but it is still not clear to me: https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/ )

import tensorflow as tf

def discriminator(images, reuse=False):
    """
    Create the discriminator network
    """
    alpha = 0.2

    with tf.variable_scope('discriminator', reuse=reuse):
        # using 4 layer network as in DCGAN Paper

        # Conv 1
        conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME')
        lrelu1 = tf.maximum(alpha * conv1, conv1)

        # Conv 2
        conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME')
        batch_norm2 = tf.layers.batch_normalization(conv2, training=True)
        lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)

        # Conv 3
        conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME')
        batch_norm3 = tf.layers.batch_normalization(conv3, training=True)
        lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)

        # Flatten
        flat = tf.reshape(lrelu3, (-1, 4*4*256))

        # Logits
        logits = tf.layers.dense(flat, 1)

        # Output
        out = tf.sigmoid(logits)

        return out, logits


def generator(z, out_channel_dim, is_train=True):
    """
    Create the generator network
    """
    alpha = 0.2

    with tf.variable_scope('generator', reuse=not is_train):
        # First fully connected layer
        x_1 = tf.layers.dense(z, 2*2*512)

        # Reshape it to start the convolutional stack
        deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512))
        batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train)
        lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)


        # Deconv 1
        deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID')
        batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train)
        lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)



        # Deconv 2
        deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME')
        batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train)
        lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4)


        #Deconv 3
        deconv5 = tf.layers.conv2d_transpose(lrelu4, 64, 5, 2, padding='SAME')
        batch_norm5 = tf.layers.batch_normalization(deconv5, training=is_train)
        lrelu5 = tf.maximum(alpha * batch_norm5, batch_norm5)



        # Output layer
        logits = tf.layers.conv2d_transpose(lrelu5, out_channel_dim, 5, 2, padding='SAME')
        out = tf.tanh(logits)

        return out

Thanks.

For the generator, we will train it, but we will also sample from it both during and after training. The discriminator will need to share variables between the fake and real input images. So we can use the reuse keyword of tf.variable_scope to tell TensorFlow to reuse the variables instead of creating new ones when the graph is built again.
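
For example, a minimal sketch of the two generator calls (input_z and out_channel_dim are assumed to be defined elsewhere):

# Training pass: is_train=True, so reuse=False and the variables
# in the 'generator' scope are created for the first time.
g_model = generator(input_z, out_channel_dim, is_train=True)

# Sampling pass: is_train=False, so reuse=True and this call
# reuses the trained weights instead of creating a second copy.
samples = generator(input_z, out_channel_dim, is_train=False)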

Then the discriminator. We will build two of them, one for real data and one for fake data. Since we want the weights to be the same for both real and fake data, we need to reuse the variables. For the fake data we take the generator's output, g_model. So the real-data discriminator is discriminator(input_real), while the fake one is discriminator(g_model, reuse=True).
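
Putting it together, here is a minimal sketch of the graph construction; the placeholder names and shapes below are assumptions for illustration, not from the original post. Note that without reuse=True, the second discriminator call would try to create variables that already exist in the 'discriminator' scope and TensorFlow would raise a ValueError.

import tensorflow as tf

# Hypothetical inputs; the shapes are illustrative only.
input_real = tf.placeholder(tf.float32, (None, 16, 16, 3), name='input_real')
input_z = tf.placeholder(tf.float32, (None, 100), name='input_z')

# Fake images from the generator (creates the 'generator' variables).
g_model = generator(input_z, out_channel_dim=3, is_train=True)

# First call creates the 'discriminator' variables (reuse=False).
d_model_real, d_logits_real = discriminator(input_real)

# Second call shares those same weights for the fake images.
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)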
