[英]What is the purpose of the 'reuse' flag in GAN implemented in tensorflow?
我正在閱讀GAN教程,並且注意到“重用”標志的使用,但我不太了解它們在做什么或在做什么。 如果查看下面的代碼,您將看到在每個變量范圍初始化中都使用了reuse
。
(我嘗試查看文檔,但仍然不清楚: https : //www.tensorflow.org/versions/r0.12/how_tos/variable_scope/ )
def discriminator(images, reuse=False):
"""
Create the discriminator network
"""
alpha = 0.2
with tf.variable_scope('discriminator', reuse=reuse):
# using 4 layer network as in DCGAN Paper
# Conv 1
conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME')
lrelu1 = tf.maximum(alpha * conv1, conv1)
# Conv 2
conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME')
batch_norm2 = tf.layers.batch_normalization(conv2, training=True)
lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
# Conv 3
conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME')
batch_norm3 = tf.layers.batch_normalization(conv3, training=True)
lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
# Flatten
flat = tf.reshape(lrelu3, (-1, 4*4*256))
# Logits
logits = tf.layers.dense(flat, 1)
# Output
out = tf.sigmoid(logits)
return out, logits
def generator(z, out_channel_dim, is_train=True):
"""
Create the generator network
"""
alpha = 0.2
with tf.variable_scope('generator', reuse=False if is_train==True else True):
# First fully connected layer
x_1 = tf.layers.dense(z, 2*2*512)
# Reshape it to start the convolutional stack
deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512))
batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train)
lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
# Deconv 1
deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID')
batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train)
lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
# Deconv 2
deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME')
batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train)
lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4)
#Deconv 3
deconv5 = tf.layers.conv2d_transpose(lrelu4, 64, 5, 2, padding='SAME')
batch_norm5 = tf.layers.batch_normalization(deconv5, training=is_train)
lrelu5 = tf.maximum(alpha * batch_norm5, batch_norm5)
# Output layer
logits = tf.layers.conv2d_transpose(lrelu5, out_channel_dim, 5, 2, padding='SAME')
out = tf.tanh(logits)
return out
謝謝。
對於生成器,我們將對其進行訓練,同時在訓練過程中和訓練后還要從中進行采樣。 鑒別器將需要在假輸入圖像和真實輸入圖像之間共享變量。 因此,我們可以使用tf.variable_scope的reuse關鍵字來告訴TensorFlow重用變量,而不是在再次構建圖形時創建新變量。
然后是鑒別符。 我們將構建其中的兩個,一個用於真實數據,另一個用於虛假數據。 由於我們希望真實和偽數據的權重都相同,因此我們需要重用變量。 對於虛假數據,我們從生成器獲取它為g_model。 因此,真正的數據鑒別符是鑒別符(input_real),而假鑒別符是鑒別符(g_model,reuse = True)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.