Generator가 자동 인코딩이고 Discriminator가 Convolutional Neural Net 인 이진 출력이있는 Tensorflow가있는 GAN을 개발하고 싶습니다. 자동 코드 작성기와 CNN을 개발하는 데는 문제가 없지만, 제 생각에는 구성 요소 (Discriminator 및 Generator) 각각에 대해 1 에포크를 교육하고 이전주기의 결과 (가중치)를 유지하면서이주기를 1000 개 에포크 반복하십시오. 다음을 위해. 어떻게 이것을 조작 할 수 있습니까?Tensorflow에서 GAN Generator 및 Discriminator를 비동기 적으로 업데이트하는 방법은 무엇입니까?
1
A
답변
1
두 개의 작전 호출이있는 경우 train_step_generator
및 train_step_discriminator
(각각의, 예를 들어, 형태 각각에 대한 적절한 손실 tf.train.AdamOptimizer().minimize(loss)
의), 다음 훈련 루프는 다음과 같은 구조와 유사한 구조이어야한다 :
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(1000):
if epoch%2 == 0: # train discriminator on even epochs
for i in range(training_set_size/batch_size):
z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
batch = get_next_batch(batch_size)
sess.run(train_step_discriminator,feed_dict={z:z_, x:batch})
else: # train generator on odd epochs
for i in range(training_set_size/batch_size):
z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
sess.run(train_step_generator,feed_dict={z:z_})
가중치는 반복 사이에 유지됩니다.
0
나는이 문제를 해결했다. 사실, 자동 인코딩 장치의 출력을 CNN의 입력으로 사용하여 GAN을 연결하고 가중치를 1 : 1의 비율로 업데이트하려고합니다. 나는 두 번째 루프가 시작될 때 생성기의 텐서 손실이 Discriminator에 의해 생성 된 마지막 손실 인 float로 대체 될 수 있도록 발전기와 discriminator의 손실을 구분하는 특별한주의가 필요하다는 것을 알아 냈습니다.
Here's 코드 여기
with tf.Session() as sess:
sess.run(init)
for i in range(1, num_steps+1):
여기 생성기 훈련
batch_x, batch_y=next_batch(batch_size, x_train_noisy, x_train)
_, l = sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n,784),
Y:batch_y})
if i % display_step == 0 or i == 1:
print('Epoch %i: Denoising Loss: %f' % (i, l))
발전기의 출력은 판별
output=sess.run([decoder_op],feed_dict={X: x_train})
x_train2=np.array(output).reshape(n,784).astype(np.float64)
위한 입력으로 사용될 여기서 Discriminator 훈련
batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
sess.run(train_op, feed_dict={X2: batch_x2.reshape(n,784), Y2: batch_y2, keep_prob: 0.8})
if i % display_step == 0 or i == 1:
loss3, acc = sess.run([loss_op2, accuracy], feed_dict={X2: batch_x2,
Y2: batch_y2,
keep_prob: 1.0})
print("Epoch " + str(i) + ", CNN Loss= " + \
"{:.4f}".format(loss3) + ", Training Accuracy= " + "{:.3f}".format(acc))
비동기 업데이트 비율 1 조작화 할 수있다이 방법 : 1, 1 : 5, 5 : 1 (차별 : 발전기) 또는 임의의 다른 방법