aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-13 17:31:21 +0000
committernunzip <np.scarh@gmail.com>2019-03-13 17:31:21 +0000
commit1fb7afb062fb50e0be0f32d3c8969e5ec1e72314 (patch)
tree320d306fc3f5c46db5945ebde194cf501492759c
parent79d666afdf6517ea15bfc9b882f7e4e77bff295b (diff)
downloade4-gan-1fb7afb062fb50e0be0f32d3c8969e5ec1e72314.tar.gz
e4-gan-1fb7afb062fb50e0be0f32d3c8969e5ec1e72314.tar.bz2
e4-gan-1fb7afb062fb50e0be0f32d3c8969e5ec1e72314.zip
Fix gd attempt
-rwxr-xr-xcgan.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/cgan.py b/cgan.py
index d579e33..b68e4ab 100755
--- a/cgan.py
+++ b/cgan.py
@@ -108,7 +108,7 @@ class CGAN():
return Model([img, label], validity)
- def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0):
+ def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0, gdstep=1):
# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data()
@@ -153,7 +153,7 @@ class CGAN():
# Condition on labels
sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
# Train the generator
- if epoch % 3 == 0
+ if epoch % gdstep == 0:
g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
else:
g_loss = 0