diff options
author | nunzip <np.scarh@gmail.com> | 2019-03-13 17:31:21 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-03-13 17:31:21 +0000 |
commit | 1fb7afb062fb50e0be0f32d3c8969e5ec1e72314 (patch) | |
tree | 320d306fc3f5c46db5945ebde194cf501492759c | |
parent | 79d666afdf6517ea15bfc9b882f7e4e77bff295b (diff) | |
download | e4-gan-1fb7afb062fb50e0be0f32d3c8969e5ec1e72314.tar.gz e4-gan-1fb7afb062fb50e0be0f32d3c8969e5ec1e72314.tar.bz2 e4-gan-1fb7afb062fb50e0be0f32d3c8969e5ec1e72314.zip |
Fix gd attempt
-rwxr-xr-x | cgan.py | 4 |
1 files changed, 2 insertions, 2 deletions
@@ -108,7 +108,7 @@ class CGAN(): return Model([img, label], validity) - def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): + def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0, gdstep=1): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() @@ -153,7 +153,7 @@ class CGAN(): # Condition on labels sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # Train the generator - if epoch % 3 == 0 + if epoch % gdstep == 0: g_loss = self.combined.train_on_batch([noise, sampled_labels], valid) else: g_loss = 0 |