aboutsummaryrefslogtreecommitdiff
path: root/dcgan.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-03-04 16:41:08 +0000
committerVasil Zlatanov <v@skozl.com>2019-03-04 16:41:08 +0000
commit5ffa17b2381aa1f298f9d9457bda09a2d9907a9b (patch)
treef209962099cd30393a1fd4edd35687794ec4a319 /dcgan.py
parent2ce97744ff9357f74299d6498dd0e5510a7ddb7a (diff)
downloade4-gan-5ffa17b2381aa1f298f9d9457bda09a2d9907a9b.tar.gz
e4-gan-5ffa17b2381aa1f298f9d9457bda09a2d9907a9b.tar.bz2
e4-gan-5ffa17b2381aa1f298f9d9457bda09a2d9907a9b.zip
Add repeateable conv_layers to dcgan
Diffstat (limited to 'dcgan.py')
-rw-r--r--dcgan.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/dcgan.py b/dcgan.py
index b48e99f..a0a26c9 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -15,13 +15,14 @@ import sys
import numpy as np
class DCGAN():
- def __init__(self):
+ def __init__(self, conv_layers = 1):
# Input shape
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
+ self.conv_layers = conv_layers
optimizer = Adam(0.002, 0.5)
@@ -56,13 +57,19 @@ class DCGAN():
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(UpSampling2D())
- model.add(Conv2D(128, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
- model.add(Activation("relu"))
+
+ for i in range(self.conv_layers):
+ model.add(Conv2D(128, kernel_size=3, padding="same"))
+ model.add(BatchNormalization())
+ model.add(Activation("relu"))
+
model.add(UpSampling2D())
- model.add(Conv2D(64, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
- model.add(Activation("relu"))
+
+ for i in range(self.conv_layers):
+ model.add(Conv2D(64, kernel_size=3, padding="same"))
+ model.add(BatchNormalization())
+ model.add(Activation("relu"))
+
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
model.add(Activation("tanh"))
@@ -183,4 +190,4 @@ class DCGAN():
if __name__ == '__main__':
dcgan = DCGAN()
dcgan.train(epochs=4000, batch_size=32, save_interval=50)
-''' \ No newline at end of file
+'''