diff options
author | nunzip <np.scarh@gmail.com> | 2019-03-13 18:52:18 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-03-13 18:52:18 +0000 |
commit | 9945d9fe431f0b01c528b311acb685bebd99ab48 (patch) | |
tree | d7644fbacfac2416266304582c9fd2932ffbb337 | |
parent | 1fb7afb062fb50e0be0f32d3c8969e5ec1e72314 (diff) | |
download | e4-gan-9945d9fe431f0b01c528b311acb685bebd99ab48.tar.gz e4-gan-9945d9fe431f0b01c528b311acb685bebd99ab48.tar.bz2 e4-gan-9945d9fe431f0b01c528b311acb685bebd99ab48.zip |
Introduce gd balancing in DCGAN
-rw-r--r-- | dcgan.py | 7 |
1 files changed, 5 insertions, 2 deletions
@@ -113,7 +113,7 @@ class DCGAN(): return Model(img, validity) - def train(self, epochs, batch_size=128, save_interval=50, VBN=False): + def train(self, epochs, batch_size=128, save_interval=50, VBN=False, gdstep=1): # Load the dataset (X_train, _), (_, _) = mnist.load_data() @@ -153,7 +153,10 @@ class DCGAN(): # --------------------- # Train the generator (wants discriminator to mistake images as real) - g_loss = self.combined.train_on_batch(noise, valid) + if epoch % gdstep == 0: + g_loss = self.combined.train_on_batch(noise, valid) + else: + g_loss = 0 # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) |