aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-13 17:25:54 +0000
committernunzip <np.scarh@gmail.com>2019-03-13 17:25:54 +0000
commit79d666afdf6517ea15bfc9b882f7e4e77bff295b (patch)
tree180715aa4b35fc6297177ec58d7ecdd6af3c879f
parent672bdd094082d5be99b3149269a00f94875d0698 (diff)
downloade4-gan-79d666afdf6517ea15bfc9b882f7e4e77bff295b.tar.gz
e4-gan-79d666afdf6517ea15bfc9b882f7e4e77bff295b.tar.bz2
e4-gan-79d666afdf6517ea15bfc9b882f7e4e77bff295b.zip
Try GD
-rwxr-xr-xcgan.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/cgan.py b/cgan.py
index a34a0e3..d579e33 100755
--- a/cgan.py
+++ b/cgan.py
@@ -141,6 +141,7 @@ class CGAN():
gen_imgs = self.generator.predict([noise, labels])
# Train the discriminator
+
d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid*smooth_real)
d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], valid*smooth_fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
@@ -152,7 +153,10 @@ class CGAN():
# Condition on labels
sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
# Train the generator
- g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
+ if epoch % 3 == 0
+ g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
+ else:
+ g_loss = 0
# Plot the progress
#print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))