aboutsummaryrefslogtreecommitdiff
path: root/cdcgan.py
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-14 01:33:07 +0000
committernunzip <np.scarh@gmail.com>2019-03-14 01:33:07 +0000
commit48f3bca39023fe862be6d212ecd47c3f34648260 (patch)
tree06e409d1eddba3261ab1aa5ba475d4ab818eb67d /cdcgan.py
parent2273313e2d818ab26f1c7a0c6bb89d5728611ad7 (diff)
parent00ee8e36064ed247643a68c7fa8591d5a17347d9 (diff)
downloade4-gan-48f3bca39023fe862be6d212ecd47c3f34648260.tar.gz
e4-gan-48f3bca39023fe862be6d212ecd47c3f34648260.tar.bz2
e4-gan-48f3bca39023fe862be6d212ecd47c3f34648260.zip
Merge branch 'master' of skozl.com:e4-gan
Diffstat (limited to 'cdcgan.py')
-rwxr-xr-xcdcgan.py66
1 files changed, 34 insertions, 32 deletions
diff --git a/cdcgan.py b/cdcgan.py
index effc89b..a69dbc8 100755
--- a/cdcgan.py
+++ b/cdcgan.py
@@ -1,8 +1,5 @@
from __future__ import print_function, division
-import tensorflow as keras
-
-import tensorflow as tf
-import tensorflow.keras as keras
+import keras
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D
@@ -39,11 +36,12 @@ class CDCGAN():
optimizer=optimizer,
metrics=['accuracy'])
- # Build the generator
- self.generator = self.build_generator()
-
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,))
+
+ # Build the generator
+ self.generator = self.build_generator(noise, label)
+
img = self.generator([noise, label])
# For the combined model we will only train the generator
@@ -59,38 +57,44 @@ class CDCGAN():
self.combined.compile(loss=['binary_crossentropy'],
optimizer=optimizer)
- def build_generator(self):
+ def build_generator(self, noise_in, label_in):
+ noise = Dense(7 * 7 * 256)(noise_in)
+ noise = Reshape(target_shape=(7, 7, 256))(noise)
+ noise = Conv2DTranspose(256, kernel_size=3, padding="same")(noise)
+ noise = BatchNormalization()(noise)
+ noise = Activation("relu")(noise)
- model = Sequential()
+ label = Flatten()(Embedding(self.num_classes, self.latent_dim)(label_in))
+ label = Dense(7 * 7 * 256)(label)
+ label = Reshape(target_shape=(7, 7, 256))(label)
+ label = Conv2DTranspose(256, kernel_size=3, padding="same")(label)
+ label = BatchNormalization()(label)
+ label = Activation("relu")(label)
- model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
- model.add(Reshape((7, 7, 128)))
+ # Combine the two
- model.add(Conv2DTranspose(256, kernel_size=3, padding="same", strides=(2,2)))
- model.add(BatchNormalization())
- model.add(Activation("relu"))
+ x = keras.layers.Concatenate()([noise, label])
- model.add(Conv2DTranspose(128, kernel_size=3, padding="same", strides=(2,2)))
- model.add(BatchNormalization())
- model.add(Activation("relu"))
+ x = Conv2DTranspose(256, kernel_size=3, padding="same")(x)
+ x = BatchNormalization()(x)
+ x = Activation("relu")(x)
- model.add(Conv2DTranspose(64, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
- model.add(Activation("relu"))
+ x = Conv2DTranspose(128, kernel_size=3, padding="same", strides=(2,2))(x)
+ x = BatchNormalization()(x)
+ x = Activation("relu")(x)
+
+ x = Conv2DTranspose(64, kernel_size=3, padding="same", strides=(2,2))(x)
+ x = BatchNormalization()(x)
+ x = Activation("relu")(x)
- model.add(Conv2DTranspose(1, kernel_size=3, padding="same"))
- model.add(Activation("tanh"))
+ x = (Conv2DTranspose(1, kernel_size=3, padding="same"))(x)
+ x = Activation("tanh")(x)
+ model = Model([noise_in, label_in], outputs=x)
- noise = Input(shape=(self.latent_dim,))
- label = Input(shape=(1,), dtype='int32')
- label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
- model_input = multiply([noise, label_embedding])
- img = model(model_input)
-
- #model.summary()
+ model.summary()
- return Model([noise, label], img)
+ return model
def build_discriminator(self):
@@ -242,8 +246,6 @@ class CDCGAN():
return train_data, test_data, val_data, labels_train, labels_test, labels_val
-'''
if __name__ == '__main__':
cdcgan = CDCGAN()
cdcgan.train(epochs=4000, batch_size=32)
-'''