나는 GAN 튜토리얼을 통해 갈 것이고 '재사용'플래그를 사용했음을 눈치 챘다. 아래의 코드를 살펴보면 reuse
이 각 변수 범위 초기화 내에서 사용되는 것을 볼 수 있습니다. TAN에서 구현 된 GAN의 '재사용'플래그의 목적은 무엇입니까?
def discriminator(images, reuse=False):
"""
Create the discriminator network
"""
alpha = 0.2
with tf.variable_scope('discriminator', reuse=reuse):
# using 4 layer network as in DCGAN Paper
# Conv 1
conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME')
lrelu1 = tf.maximum(alpha * conv1, conv1)
# Conv 2
conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME')
batch_norm2 = tf.layers.batch_normalization(conv2, training=True)
lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
# Conv 3
conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME')
batch_norm3 = tf.layers.batch_normalization(conv3, training=True)
lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
# Flatten
flat = tf.reshape(lrelu3, (-1, 4*4*256))
# Logits
logits = tf.layers.dense(flat, 1)
# Output
out = tf.sigmoid(logits)
return out, logits
def generator(z, out_channel_dim, is_train=True):
"""
Create the generator network
"""
alpha = 0.2
with tf.variable_scope('generator', reuse=False if is_train==True else True):
# First fully connected layer
x_1 = tf.layers.dense(z, 2*2*512)
# Reshape it to start the convolutional stack
deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512))
batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train)
lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
# Deconv 1
deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID')
batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train)
lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
# Deconv 2
deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME')
batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train)
lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4)
#Deconv 3
deconv5 = tf.layers.conv2d_transpose(lrelu4, 64, 5, 2, padding='SAME')
batch_norm5 = tf.layers.batch_normalization(deconv5, training=is_train)
lrelu5 = tf.maximum(alpha * batch_norm5, batch_norm5)
# Output layer
logits = tf.layers.conv2d_transpose(lrelu5, out_channel_dim, 5, 2, padding='SAME')
out = tf.tanh(logits)
return out
감사합니다.
감사합니다. 각 판별자가 자신의 그래프를 가지고 있다는 것을 의미합니까? 하나의 그래프라면 공유 할 필요가 없다고 가정하고 있습니까? – Moondra
@Moondra 그래프는 하나이지만 값은 동일해야합니다 ... 일반적인 피드 - 포워드를 구현하는 동안에도 그래프는 동일하게 유지됩니다 ... 변화하는 그래프의 값 – Jai