From bc501637cdb329db681b439563cdae418f3fa897 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Wed, 6 Mar 2019 20:39:00 +0000 Subject: Revert "Add virtual_batch support" This reverts commit 740e1b0c6a02a7bec20008758373f0dd80baade4. --- dcgan.py | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) (limited to 'dcgan.py') diff --git a/dcgan.py b/dcgan.py index 21afaac..347f61e 100644 --- a/dcgan.py +++ b/dcgan.py @@ -1,16 +1,11 @@ from __future__ import print_function, division -import tensorflow as keras - -import tensorflow as tf -from tensorflow.keras.datasets import mnist -from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout -from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D -from tensorflow.keras.layers import LeakyReLU -from tensorflow.keras.layers import UpSampling2D, Conv2D -from tensorflow.keras.models import Sequential, Model -from tensorflow.keras.optimizers import Adam - -from lib.virtual_batch import VirtualBatchNormalization +from keras.datasets import mnist +from keras.layers import Input, Dense, Reshape, Flatten, Dropout +from keras.layers import BatchNormalization, Activation, ZeroPadding2D +from keras.layers.advanced_activations import LeakyReLU +from keras.layers.convolutional import UpSampling2D, Conv2D +from keras.models import Sequential, Model +from keras.optimizers import Adam import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec @@ -22,7 +17,7 @@ import sys import numpy as np class DCGAN(): - def __init__(self, conv_layers = 1, virtual_batch_normalization=False): + def __init__(self, conv_layers = 1): # Input shape self.img_rows = 28 self.img_cols = 28 @@ -30,7 +25,6 @@ class DCGAN(): self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers - self.virtual_batch_normalization = virtual_batch_normalization optimizer = Adam(0.002, 0.5) @@ -68,21 +62,14 @@ class DCGAN(): for i in range(self.conv_layers): model.add(Conv2D(128, kernel_size=3, padding="same")) - if self.virtual_batch_normalization: - model.add(VirtualBatchNormalization()) - else: - model.add(BatchNormalization()) + model.add(BatchNormalization()) model.add(Activation("relu")) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(64, kernel_size=3, padding="same")) - if self.virtual_batch_normalization: - model.add(VirtualBatchNormalization()) - else: - model.add(BatchNormalization()) - + model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=3, padding="same")) @@ -151,7 +138,6 @@ class DCGAN(): idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] - tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Sample noise and generate a batch of new images noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) gen_imgs = self.generator.predict(noise) @@ -203,6 +189,6 @@ class DCGAN(): ''' if __name__ == '__main__': - dcgan = DCGAN(virtual_batch_normalization=True) + dcgan = DCGAN() dcgan.train(epochs=4000, batch_size=32, save_interval=50) ''' -- cgit v1.2.3-54-g00ecf