diff options
author | Vasil Zlatanov <v@skozl.com> | 2019-03-04 16:26:35 +0000 |
---|---|---|
committer | Vasil Zlatanov <v@skozl.com> | 2019-03-04 16:26:35 +0000 |
commit | 2ce97744ff9357f74299d6498dd0e5510a7ddb7a (patch) | |
tree | c9470e5c5424bd2fbb788d30d53ebd7933f09cae | |
parent | 7a7818bda1a72bd1819bc7803cdc02b3af1e103d (diff) | |
download | e4-gan-2ce97744ff9357f74299d6498dd0e5510a7ddb7a.tar.gz e4-gan-2ce97744ff9357f74299d6498dd0e5510a7ddb7a.tar.bz2 e4-gan-2ce97744ff9357f74299d6498dd0e5510a7ddb7a.zip |
Add dense_layers argument for CGAN()
-rwxr-xr-x[-rw-r--r--] | cgan.py | 21 |
1 files changed, 10 insertions, 11 deletions
@@ -9,11 +9,12 @@ from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam import matplotlib.pyplot as plt +from IPython.display import clear_output import numpy as np class CGAN(): - def __init__(self): + def __init__(self, dense_layers = 3): # Input shape self.img_rows = 28 self.img_cols = 28 @@ -21,6 +22,7 @@ class CGAN(): self.img_shape = (self.img_rows, self.img_cols, self.channels) self.num_classes = 10 self.latent_dim = 100 + self.dense_layers = dense_layers optimizer = Adam(0.0002, 0.5) @@ -56,15 +58,12 @@ class CGAN(): model = Sequential() - model.add(Dense(256, input_dim=self.latent_dim)) - model.add(LeakyReLU(alpha=0.2)) - model.add(BatchNormalization(momentum=0.8)) - model.add(Dense(512)) - model.add(LeakyReLU(alpha=0.2)) - model.add(BatchNormalization(momentum=0.8)) - model.add(Dense(1024)) - model.add(LeakyReLU(alpha=0.2)) - model.add(BatchNormalization(momentum=0.8)) + for i in range(self.dense_layers): + output_size = 2**(8+i) + model.add(Dense(output_size, input_dim=self.latent_dim)) + model.add(LeakyReLU(alpha=0.2)) + model.add(BatchNormalization(momentum=0.8)) + model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) @@ -223,7 +222,7 @@ class CGAN(): ''' if __name__ == '__main__': - cgan = CGAN() + cgan = CGAN(dense_layers=1) cgan.train(epochs=7000, batch_size=32, sample_interval=200) train, test, tr_labels, te_labels = cgan.generate_data() print(train.shape, test.shape) |