aboutsummaryrefslogtreecommitdiff
path: root/cdcgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'cdcgan.py')
-rwxr-xr-xcdcgan.py27
1 files changed, 13 insertions, 14 deletions
diff --git a/cdcgan.py b/cdcgan.py
index 8d59a03..effc89b 100755
--- a/cdcgan.py
+++ b/cdcgan.py
@@ -7,7 +7,7 @@ from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
-from keras.layers import UpSampling2D, Conv2D
+from keras.layers import UpSampling2D, Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam
@@ -65,25 +65,22 @@ class CDCGAN():
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
- model.add(UpSampling2D())
- for i in range(self.conv_layers):
- model.add(Conv2D(128, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
- model.add(Activation("relu"))
-
- model.add(UpSampling2D())
+ model.add(Conv2DTranspose(256, kernel_size=3, padding="same", strides=(2,2)))
+ model.add(BatchNormalization())
+ model.add(Activation("relu"))
- for i in range(self.conv_layers):
- model.add(Conv2D(64, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ model.add(Conv2DTranspose(128, kernel_size=3, padding="same", strides=(2,2)))
+ model.add(BatchNormalization())
+ model.add(Activation("relu"))
- model.add(Activation("relu"))
+ model.add(Conv2DTranspose(64, kernel_size=3, padding="same"))
+ model.add(BatchNormalization())
+ model.add(Activation("relu"))
- model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
+ model.add(Conv2DTranspose(1, kernel_size=3, padding="same"))
model.add(Activation("tanh"))
- #model.summary()
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,), dtype='int32')
@@ -91,6 +88,8 @@ class CDCGAN():
model_input = multiply([noise, label_embedding])
img = model(model_input)
+ #model.summary()
+
return Model([noise, label], img)
def build_discriminator(self):