aboutsummaryrefslogtreecommitdiff
path: root/dcgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'dcgan.py')
-rw-r--r--dcgan.py44
1 files changed, 25 insertions, 19 deletions
diff --git a/dcgan.py b/dcgan.py
index 0d0ff12..21afaac 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -1,11 +1,16 @@
from __future__ import print_function, division
-from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
-from keras.optimizers import Adam
+import tensorflow as keras
+
+import tensorflow as tf
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from tensorflow.keras.layers import LeakyReLU
+from tensorflow.keras.layers import UpSampling2D, Conv2D
+from tensorflow.keras.models import Sequential, Model
+from tensorflow.keras.optimizers import Adam
+
+from lib.virtual_batch import VirtualBatchNormalization
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
@@ -17,7 +22,7 @@ import sys
import numpy as np
class DCGAN():
- def __init__(self, conv_layers = 1):
+ def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -25,6 +30,7 @@ class DCGAN():
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
self.conv_layers = conv_layers
+ self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.002, 0.5)
@@ -62,14 +68,21 @@ class DCGAN():
for i in range(self.conv_layers):
model.add(Conv2D(128, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization())
+ else:
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(UpSampling2D())
for i in range(self.conv_layers):
model.add(Conv2D(64, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization())
+ else:
+ model.add(BatchNormalization())
+
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -138,14 +151,7 @@ class DCGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
- if VBN:
- idx = np.random.randint(0, X_train.shape[0], batch_size)
- ref_imgs = X_train[idx]
- mu = np.mean(ref_imgs, axis=0)
- sigma = 1#np.var(ref_imgs, axis=0)
- #need to redefine sigma because of division by zero
- imgs = np.divide(np.subtract(imgs, mu), sigma)
-
+ tf.keras.backend.get_session().run(tf.global_variables_initializer())
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)
@@ -197,6 +203,6 @@ class DCGAN():
'''
if __name__ == '__main__':
- dcgan = DCGAN()
+ dcgan = DCGAN(virtual_batch_normalization=True)
dcgan.train(epochs=4000, batch_size=32, save_interval=50)
'''