aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-13 18:52:18 +0000
committernunzip <np.scarh@gmail.com>2019-03-13 18:52:18 +0000
commit9945d9fe431f0b01c528b311acb685bebd99ab48 (patch)
treed7644fbacfac2416266304582c9fd2932ffbb337
parent1fb7afb062fb50e0be0f32d3c8969e5ec1e72314 (diff)
downloade4-gan-9945d9fe431f0b01c528b311acb685bebd99ab48.tar.gz
e4-gan-9945d9fe431f0b01c528b311acb685bebd99ab48.tar.bz2
e4-gan-9945d9fe431f0b01c528b311acb685bebd99ab48.zip
Introduce gd balancing in DCGAN
-rw-r--r--dcgan.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/dcgan.py b/dcgan.py
index 7844843..4317994 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -113,7 +113,7 @@ class DCGAN():
return Model(img, validity)
- def train(self, epochs, batch_size=128, save_interval=50, VBN=False):
+ def train(self, epochs, batch_size=128, save_interval=50, VBN=False, gdstep=1):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
@@ -153,7 +153,10 @@ class DCGAN():
# ---------------------
# Train the generator (wants discriminator to mistake images as real)
- g_loss = self.combined.train_on_batch(noise, valid)
+ if epoch % gdstep == 0:
+ g_loss = self.combined.train_on_batch(noise, valid)
+ else:
+ g_loss = 0
# Plot the progress
#print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))