aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-03-04 16:26:35 +0000
committerVasil Zlatanov <v@skozl.com>2019-03-04 16:26:35 +0000
commit2ce97744ff9357f74299d6498dd0e5510a7ddb7a (patch)
treec9470e5c5424bd2fbb788d30d53ebd7933f09cae
parent7a7818bda1a72bd1819bc7803cdc02b3af1e103d (diff)
downloade4-gan-2ce97744ff9357f74299d6498dd0e5510a7ddb7a.tar.gz
e4-gan-2ce97744ff9357f74299d6498dd0e5510a7ddb7a.tar.bz2
e4-gan-2ce97744ff9357f74299d6498dd0e5510a7ddb7a.zip
Add dense_layers argument for CGAN()
-rwxr-xr-x[-rw-r--r--]cgan.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/cgan.py b/cgan.py
index 3c89121..68256f3 100644..100755
--- a/cgan.py
+++ b/cgan.py
@@ -9,11 +9,12 @@ from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
+from IPython.display import clear_output
import numpy as np
class CGAN():
- def __init__(self):
+ def __init__(self, dense_layers = 3):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -21,6 +22,7 @@ class CGAN():
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.num_classes = 10
self.latent_dim = 100
+ self.dense_layers = dense_layers
optimizer = Adam(0.0002, 0.5)
@@ -56,15 +58,12 @@ class CGAN():
model = Sequential()
- model.add(Dense(256, input_dim=self.latent_dim))
- model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Dense(512))
- model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Dense(1024))
- model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
+ for i in range(self.dense_layers):
+ output_size = 2**(8+i)
+ model.add(Dense(output_size, input_dim=self.latent_dim))
+ model.add(LeakyReLU(alpha=0.2))
+ model.add(BatchNormalization(momentum=0.8))
+
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
@@ -223,7 +222,7 @@ class CGAN():
'''
if __name__ == '__main__':
- cgan = CGAN()
+ cgan = CGAN(dense_layers=1)
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)