Diffstat (limited to 'dcgan.py')
-rw-r--r-- | dcgan.py | 34 |
1 files changed, 24 insertions, 10 deletions
@@ -1,11 +1,15 @@
 from __future__ import print_function, division
-from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
-from keras.optimizers import Adam
+import tensorflow as tf
+
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from tensorflow.keras.layers import LeakyReLU
+from tensorflow.keras.layers import UpSampling2D, Conv2D
+from tensorflow.keras.models import Sequential, Model
+from tensorflow.keras.optimizers import Adam
+
+from lib.virtual_batch import VirtualBatchNormalization
 import matplotlib.pyplot as plt
 import matplotlib.gridspec as gridspec
 
@@ -17,7 +21,7 @@ import sys
 import numpy as np
 
 class DCGAN():
-    def __init__(self, conv_layers = 1):
+    def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
         # Input shape
         self.img_rows = 28
         self.img_cols = 28
@@ -25,6 +29,7 @@ class DCGAN():
         self.img_shape = (self.img_rows, self.img_cols, self.channels)
         self.latent_dim = 100
         self.conv_layers = conv_layers
+        self.virtual_batch_normalization = virtual_batch_normalization
 
         optimizer = Adam(0.002, 0.5)
 
@@ -62,14 +67,21 @@ class DCGAN():
 
         for i in range(self.conv_layers):
             model.add(Conv2D(128, kernel_size=3, padding="same"))
-            model.add(BatchNormalization())
+            if self.virtual_batch_normalization:
+                model.add(VirtualBatchNormalization())
+            else:
+                model.add(BatchNormalization())
             model.add(Activation("relu"))
 
         model.add(UpSampling2D())
 
         for i in range(self.conv_layers):
             model.add(Conv2D(64, kernel_size=3, padding="same"))
-            model.add(BatchNormalization())
+            if self.virtual_batch_normalization:
+                model.add(VirtualBatchNormalization())
+            else:
+                model.add(BatchNormalization())
+
             model.add(Activation("relu"))
 
         model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -137,6 +149,8 @@ class DCGAN():
             idx = np.random.randint(0, X_train.shape[0], batch_size)
             imgs = X_train[idx]
 
+            tf.keras.backend.get_session().run(tf.global_variables_initializer())
+
             # Sample noise and generate a batch of new images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
             gen_imgs = self.generator.predict(noise)
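
The commit imports VirtualBatchNormalization from lib/virtual_batch, but that module is not part of this diff. For orientation only, here is a minimal sketch of what such a layer could look like. It follows the reference-batch idea from Salimans et al. (2016): the constructor argument epsilon, the first-batch capture, and every other implementation detail below are assumptions, not the repository's actual code.

# Hypothetical sketch of lib/virtual_batch.py -- NOT the module this commit imports.
import tensorflow as tf

class VirtualBatchNormalization(tf.keras.layers.Layer):
    """Normalizes activations with statistics taken from a fixed reference batch.

    Instead of the per-batch statistics used by BatchNormalization, the mean and
    variance of the first batch seen are stored and reused for every later batch
    (virtual batch normalization, Salimans et al. 2016).
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super(VirtualBatchNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = int(input_shape[-1])
        # Learned scale and offset, as in ordinary batch norm.
        self.gamma = self.add_weight(name="gamma", shape=(channels,),
                                     initializer="ones", trainable=True)
        self.beta = self.add_weight(name="beta", shape=(channels,),
                                    initializer="zeros", trainable=True)
        # Frozen reference statistics, filled in on the first forward pass.
        self.ref_mean = self.add_weight(name="ref_mean", shape=(channels,),
                                        initializer="zeros", trainable=False)
        self.ref_var = self.add_weight(name="ref_var", shape=(channels,),
                                       initializer="ones", trainable=False)
        self.captured = self.add_weight(name="captured", shape=(),
                                        initializer="zeros", trainable=False)
        super(VirtualBatchNormalization, self).build(input_shape)

    def call(self, inputs):
        # Statistics over batch and spatial axes; NHWC input is assumed.
        mean, var = tf.nn.moments(inputs, axes=[0, 1, 2])

        def capture():
            # First call: store the current batch's statistics as the reference.
            with tf.control_dependencies([self.ref_mean.assign(mean),
                                          self.ref_var.assign(var),
                                          self.captured.assign(1.0)]):
                return tf.identity(mean), tf.identity(var)

        def reuse():
            # Later calls: normalize with the stored reference statistics.
            return tf.identity(self.ref_mean), tf.identity(self.ref_var)

        use_mean, use_var = tf.cond(tf.equal(self.captured, 0.0), capture, reuse)
        return tf.nn.batch_normalization(inputs, use_mean, use_var,
                                         self.beta, self.gamma, self.epsilon)

The paper computes reference statistics by propagating a separate, fixed reference batch through the network; capturing the first training batch, as in the sketch, is only one simple approximation. With a layer like this available, the flag added in this commit selects between the two schemes: DCGAN(virtual_batch_normalization=True) builds the generator with VirtualBatchNormalization, while the default keeps standard BatchNormalization.

Note also that the initializer call added inside the training loop, tf.keras.backend.get_session().run(tf.global_variables_initializer()), only exists in TensorFlow 1.x graph mode. A hedged sketch of the equivalent spelling under TensorFlow 2.x, assuming the rest of the script stays in graph mode:

import tensorflow as tf

# TensorFlow 2.x keeps the 1.x session APIs under tf.compat.v1, but they are
# only usable once eager execution is disabled.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.keras.backend.get_session().run(
    tf.compat.v1.global_variables_initializer())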