2017-12-25 14 views
1

Generator가 자동 인코딩이고 Discriminator가 Convolutional Neural Net 인 이진 출력이있는 Tensorflow가있는 GAN을 개발하고 싶습니다. 자동 코드 작성기와 CNN을 개발하는 데는 문제가 없지만, 제 생각에는 구성 요소 (Discriminator 및 Generator) 각각에 대해 1 에포크를 교육하고 이전주기의 결과 (가중치)를 유지하면서이주기를 1000 개 에포크 반복하십시오. 다음을 위해. 어떻게 이것을 조작 할 수 있습니까?Tensorflow에서 GAN Generator 및 Discriminator를 비동기 적으로 업데이트하는 방법은 무엇입니까?

답변

1

두 개의 작전 호출이있는 경우 train_step_generatortrain_step_discriminator (각각의, 예를 들어, 형태 각각에 대한 적절한 손실 tf.train.AdamOptimizer().minimize(loss)의), 다음 훈련 루프는 다음과 같은 구조와 유사한 구조이어야한다 :

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    for epoch in range(1000): 
     if epoch%2 == 0: # train discriminator on even epochs 
      for i in range(training_set_size/batch_size): 
       z_ = np.random.normal(0,1,batch_size) # this is the input to the generator 
       batch = get_next_batch(batch_size) 
       sess.run(train_step_discriminator,feed_dict={z:z_, x:batch}) 
     else: # train generator on odd epochs 
      for i in range(training_set_size/batch_size): 
       z_ = np.random.normal(0,1,batch_size) # this is the input to the generator 
       sess.run(train_step_generator,feed_dict={z:z_}) 

가중치는 반복 사이에 유지됩니다.

0

나는이 문제를 해결했다. 사실, 자동 인코딩 장치의 출력을 CNN의 입력으로 사용하여 GAN을 연결하고 가중치를 1 : 1의 비율로 업데이트하려고합니다. 나는 두 번째 루프가 시작될 때 생성기의 텐서 손실이 Discriminator에 의해 생성 된 마지막 손실 인 float로 대체 될 수 있도록 발전기와 discriminator의 손실을 구분하는 특별한주의가 필요하다는 것을 알아 냈습니다.

Here's 코드 여기

with tf.Session() as sess: 
sess.run(init) 
for i in range(1, num_steps+1): 

여기 생성기 훈련

batch_x, batch_y=next_batch(batch_size, x_train_noisy, x_train)   
    _, l = sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n,784), 
        Y:batch_y}) 
    if i % display_step == 0 or i == 1: 
     print('Epoch %i: Denoising Loss: %f' % (i, l)) 

발전기의 출력은 판별

output=sess.run([decoder_op],feed_dict={X: x_train}) 
    x_train2=np.array(output).reshape(n,784).astype(np.float64) 

위한 입력으로 사용될 여기서 Discriminator 훈련

batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train) 
    sess.run(train_op, feed_dict={X2: batch_x2.reshape(n,784), Y2: batch_y2, keep_prob: 0.8}) 
    if i % display_step == 0 or i == 1: 
     loss3, acc = sess.run([loss_op2, accuracy], feed_dict={X2: batch_x2, 
                  Y2: batch_y2, 
                  keep_prob: 1.0}) 
     print("Epoch " + str(i) + ", CNN Loss= " + \ 
       "{:.4f}".format(loss3) + ", Training Accuracy= " + "{:.3f}".format(acc)) 

비동기 업데이트 비율 1 조작화 할 수있다이 방법 : 1, 1 : 5, 5 : 1 (차별 : 발전기) 또는 임의의 다른 방법