From 1fb7afb062fb50e0be0f32d3c8969e5ec1e72314 Mon Sep 17 00:00:00 2001 From: nunzip Date: Wed, 13 Mar 2019 17:31:21 +0000 Subject: Fix gd attempt --- cgan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cgan.py b/cgan.py index d579e33..b68e4ab 100755 --- a/cgan.py +++ b/cgan.py @@ -108,7 +108,7 @@ class CGAN(): return Model([img, label], validity) - def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): + def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0, gdstep=1): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() @@ -153,7 +153,7 @@ class CGAN(): # Condition on labels sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # Train the generator - if epoch % 3 == 0 + if epoch % gdstep == 0: g_loss = self.combined.train_on_batch([noise, sampled_labels], valid) else: g_loss = 0 -- cgit v1.2.3