diff options
| -rwxr-xr-x | cgan.py | 4 | 
1 files changed, 2 insertions, 2 deletions
| @@ -108,7 +108,7 @@ class CGAN():          return Model([img, label], validity) -    def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): +    def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0, gdstep=1):          # Load the dataset          (X_train, y_train), (_, _) = mnist.load_data() @@ -153,7 +153,7 @@ class CGAN():              # Condition on labels              sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)              # Train the generator -            if epoch % 3 == 0 +            if epoch % gdstep == 0:                  g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)              else:                  g_loss = 0 | 
