aboutsummaryrefslogtreecommitdiff
path: root/dcgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'dcgan.py')
-rw-r--r--dcgan.py36
1 files changed, 11 insertions, 25 deletions
diff --git a/dcgan.py b/dcgan.py
index 21afaac..347f61e 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -1,16 +1,11 @@
from __future__ import print_function, division
-import tensorflow as keras
-
-import tensorflow as tf
-from tensorflow.keras.datasets import mnist
-from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from tensorflow.keras.layers import LeakyReLU
-from tensorflow.keras.layers import UpSampling2D, Conv2D
-from tensorflow.keras.models import Sequential, Model
-from tensorflow.keras.optimizers import Adam
-
-from lib.virtual_batch import VirtualBatchNormalization
+from keras.datasets import mnist
+from keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling2D, Conv2D
+from keras.models import Sequential, Model
+from keras.optimizers import Adam
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
@@ -22,7 +17,7 @@ import sys
import numpy as np
class DCGAN():
- def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
+ def __init__(self, conv_layers = 1):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -30,7 +25,6 @@ class DCGAN():
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
self.conv_layers = conv_layers
- self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.002, 0.5)
@@ -68,21 +62,14 @@ class DCGAN():
for i in range(self.conv_layers):
model.add(Conv2D(128, kernel_size=3, padding="same"))
- if self.virtual_batch_normalization:
- model.add(VirtualBatchNormalization())
- else:
- model.add(BatchNormalization())
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(UpSampling2D())
for i in range(self.conv_layers):
model.add(Conv2D(64, kernel_size=3, padding="same"))
- if self.virtual_batch_normalization:
- model.add(VirtualBatchNormalization())
- else:
- model.add(BatchNormalization())
-
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -151,7 +138,6 @@ class DCGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
- tf.keras.backend.get_session().run(tf.global_variables_initializer())
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)
@@ -203,6 +189,6 @@ class DCGAN():
'''
if __name__ == '__main__':
- dcgan = DCGAN(virtual_batch_normalization=True)
+ dcgan = DCGAN()
dcgan.train(epochs=4000, batch_size=32, save_interval=50)
'''