2017-12-06 11 views
0

나는 GAN 튜토리얼을 통해 갈 것이고 '재사용'플래그를 사용했음을 눈치 챘다. 아래의 코드를 살펴보면 reuse이 각 변수 범위 초기화 내에서 사용되는 것을 볼 수 있습니다. TAN에서 구현 된 GAN의 '재사용'플래그의 목적은 무엇입니까?

는 (I는 명확하지 여전히 문서에서 찾고 있지만 시도 : https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/)
def discriminator(images, reuse=False): 
    """ 
    Create the discriminator network 
    """ 
    alpha = 0.2 

    with tf.variable_scope('discriminator', reuse=reuse): 
     # using 4 layer network as in DCGAN Paper 

     # Conv 1 
     conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME') 
     lrelu1 = tf.maximum(alpha * conv1, conv1) 

     # Conv 2 
     conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME') 
     batch_norm2 = tf.layers.batch_normalization(conv2, training=True) 
     lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2) 

     # Conv 3 
     conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME') 
     batch_norm3 = tf.layers.batch_normalization(conv3, training=True) 
     lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3) 

     # Flatten 
     flat = tf.reshape(lrelu3, (-1, 4*4*256)) 

     # Logits 
     logits = tf.layers.dense(flat, 1) 

     # Output 
     out = tf.sigmoid(logits) 

     return out, logits 
def generator(z, out_channel_dim, is_train=True): 
    """ 
    Create the generator network 
    """ 
    alpha = 0.2 

    with tf.variable_scope('generator', reuse=False if is_train==True else True): 
     # First fully connected layer 
     x_1 = tf.layers.dense(z, 2*2*512) 

     # Reshape it to start the convolutional stack 
     deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512)) 
     batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train) 
     lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2) 


     # Deconv 1 
     deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID') 
     batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train) 
     lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3) 



     # Deconv 2 
     deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME') 
     batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train) 
     lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4) 


     #Deconv 3 
     deconv5 = tf.layers.conv2d_transpose(lrelu4, 64, 5, 2, padding='SAME') 
     batch_norm5 = tf.layers.batch_normalization(deconv5, training=is_train) 
     lrelu5 = tf.maximum(alpha * batch_norm5, batch_norm5) 



     # Output layer 
     logits = tf.layers.conv2d_transpose(lrelu5, out_channel_dim, 5, 2, padding='SAME') 
     out = tf.tanh(logits) 

     return out 

감사합니다.

답변

2

발전기의 경우, 우리는 그것을 훈련 할 것이지만, 우리가 훈련을하고 훈련을 할 때 그것으로부터 샘플을 얻을 것입니다. 판별 기는 가짜 및 실제 입력 이미지간에 변수를 공유해야합니다. 따라서 tf.variable_scope에 reuse 키워드를 사용하여 TensorFlow에게 그래프를 다시 작성하는 경우 새 변수를 작성하는 대신 변수를 재사용하도록 지시 할 수 있습니다.

그런 다음 판별 자. 실제 데이터 용과 가짜 데이터 용으로 두 가지를 빌드 할 것입니다. 실제 데이터와 가짜 데이터에서 가중치를 동일하게 유지하려면 변수를 다시 사용해야합니다. 가짜 데이터의 경우 생성기에서 g_model로 가져옵니다. 따라서 실제 데이터 판별 기는 판별 자 (input_real)이고 가짜 판별자는 판별 자 (g_model, reuse = True)입니다.

+0

감사합니다. 각 판별자가 자신의 그래프를 가지고 있다는 것을 의미합니까? 하나의 그래프라면 공유 할 필요가 없다고 가정하고 있습니까? – Moondra

+0

@Moondra 그래프는 하나이지만 값은 동일해야합니다 ... 일반적인 피드 - 포워드를 구현하는 동안에도 그래프는 동일하게 유지됩니다 ... 변화하는 그래프의 값 – Jai