aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcgan.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/cgan.py b/cgan.py
index d579e33..b68e4ab 100755
--- a/cgan.py
+++ b/cgan.py
@@ -108,7 +108,7 @@ class CGAN():
return Model([img, label], validity)
- def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0):
+ def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0, gdstep=1):
# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data()
@@ -153,7 +153,7 @@ class CGAN():
# Condition on labels
sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
# Train the generator
- if epoch % 3 == 0
+ if epoch % gdstep == 0:
g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
else:
g_loss = 0