diff options
-rwxr-xr-x | cgan.py | 4 |
1 files changed, 2 insertions, 2 deletions
@@ -108,7 +108,7 @@ class CGAN(): return Model([img, label], validity) - def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): + def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0, gdstep=1): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() @@ -153,7 +153,7 @@ class CGAN(): # Condition on labels sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # Train the generator - if epoch % 3 == 0 + if epoch % gdstep == 0: g_loss = self.combined.train_on_batch([noise, sampled_labels], valid) else: g_loss = 0 |