diff options
| -rw-r--r-- | dcgan.py | 11 | 
1 files changed, 10 insertions, 1 deletions
@@ -112,7 +112,7 @@ class DCGAN():          return Model(img, validity) -    def train(self, epochs, batch_size=128, save_interval=50): +    def train(self, epochs, batch_size=128, save_interval=50, VBN=False):          # Load the dataset          (X_train, _), (_, _) = mnist.load_data() @@ -127,6 +127,7 @@ class DCGAN():          xaxis = np.arange(epochs)          loss = np.zeros((2,epochs)) +          for epoch in tqdm(range(epochs)):              # --------------------- @@ -137,6 +138,14 @@ class DCGAN():              idx = np.random.randint(0, X_train.shape[0], batch_size)              imgs = X_train[idx] +            if VBN: +                idx = np.random.randint(0, X_train.shape[0], batch_size) +                ref_imgs = X_train[idx] +                mu = np.mean(ref_imgs, axis=0)  +                sigma = np.var(ref_imgs, axis=0) +                sigma[sigma<1] = 1 +                img = np.divide(np.subtract(img, mu), sigma) +              # Sample noise and generate a batch of new images              noise = np.random.normal(0, 1, (batch_size, self.latent_dim))              gen_imgs = self.generator.predict(noise)  | 
