aboutsummaryrefslogtreecommitdiff
path: root/dcgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'dcgan.py')
-rw-r--r--dcgan.py34
1 files changed, 24 insertions, 10 deletions
diff --git a/dcgan.py b/dcgan.py
index bc7e14e..eca1852 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -1,11 +1,15 @@
from __future__ import print_function, division
-from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
-from keras.optimizers import Adam
+import tensorflow as keras
+
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from tensorflow.keras.layers.advanced_activations import LeakyReLU
+from tensorflow.keras.layers.convolutional import UpSampling2D, Conv2D
+from tensorflow.keras.models import Sequential, Model
+from tensorflow.keras.optimizers import Adam
+
+from lib/virtual_batch import VirtualBatchNormalization
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
@@ -17,7 +21,7 @@ import sys
import numpy as np
class DCGAN():
- def __init__(self, conv_layers = 1):
+ def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -25,6 +29,7 @@ class DCGAN():
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
self.conv_layers = conv_layers
+ self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.002, 0.5)
@@ -62,14 +67,21 @@ class DCGAN():
for i in range(self.conv_layers):
model.add(Conv2D(128, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization())
+ else:
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(UpSampling2D())
for i in range(self.conv_layers):
model.add(Conv2D(64, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization())
+ else:
+ model.add(BatchNormalization())
+
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -137,6 +149,8 @@ class DCGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
+ tf.keras.backend.get_session().run(tf.global_variables_initializer())
+
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)