aboutsummaryrefslogtreecommitdiff
path: root/cgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'cgan.py')
-rw-r--r--[-rwxr-xr-x]cgan.py28
1 files changed, 17 insertions, 11 deletions
diff --git a/cgan.py b/cgan.py
index 5ab0c10..b9928f0 100755..100644
--- a/cgan.py
+++ b/cgan.py
@@ -1,21 +1,23 @@
from __future__ import print_function, division
import tensorflow.keras as keras
import tensorflow as tf
-from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
-from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
-from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
-from keras.optimizers import Adam
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
+from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
+from tensorflow.keras.layers import LeakyReLU
+from tensorflow.keras.layers import UpSampling2D, Conv2D
+from tensorflow.keras.models import Sequential, Model
+from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
+from lib.virtual_batch import VirtualBatchNormalization
+
import numpy as np
class CGAN():
- def __init__(self, dense_layers = 3):
+ def __init__(self, dense_layers = 3, virtual_batch_normalization=False):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -24,6 +26,7 @@ class CGAN():
self.num_classes = 10
self.latent_dim = 100
self.dense_layers = dense_layers
+ self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.0002, 0.5)
@@ -63,7 +66,10 @@ class CGAN():
output_size = 2**(8+i)
model.add(Dense(output_size, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization(momentum=0.8))
+ else:
+ model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
@@ -136,6 +142,7 @@ class CGAN():
# Sample noise as generator input
noise = np.random.normal(0, 1, (batch_size, 100))
+ tf.keras.backend.get_session().run(tf.global_variables_initializer())
# Generate a half batch of new images
gen_imgs = self.generator.predict([noise, labels])
@@ -217,10 +224,9 @@ class CGAN():
return train_data, test_data, val_data, labels_train, labels_test, labels_val
-
'''
if __name__ == '__main__':
- cgan = CGAN(dense_layers=1)
+ cgan = CGAN(dense_layers=1, virtual_batch_normalization=True)
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)