diff options
author | Vasil Zlatanov <v@skozl.com> | 2019-03-07 15:00:40 +0000 |
---|---|---|
committer | Vasil Zlatanov <v@skozl.com> | 2019-03-07 15:00:40 +0000 |
commit | 23fa20a9a8e8dc34410c400545ef182b0552e72a (patch) | |
tree | 769d6f9628c7ab15c8bfe38211d07b150e61e372 | |
parent | c9958b93e9d2e2ea9b7e7556a02736835f905df4 (diff) | |
download | e4-gan-23fa20a9a8e8dc34410c400545ef182b0552e72a.tar.gz e4-gan-23fa20a9a8e8dc34410c400545ef182b0552e72a.tar.bz2 e4-gan-23fa20a9a8e8dc34410c400545ef182b0552e72a.zip |
Rewrite cdcgan
-rwxr-xr-x | cdcgan.py | 171 |
1 files changed, 93 insertions, 78 deletions
@@ -1,45 +1,46 @@ from __future__ import print_function, division -import tensorflow.keras as keras +import tensorflow as keras + import tensorflow as tf -from keras.datasets import mnist - -import keras.layers as layers -from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply -from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D -from keras.layers.advanced_activations import LeakyReLU -from keras.layers.convolutional import UpSampling2D, Conv2D -from keras.models import Sequential, Model -from keras.optimizers import Adam +from tensorflow.keras.datasets import mnist +from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply +from tensorflow.keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D +from tensorflow.keras.layers import LeakyReLU +from tensorflow.keras.layers import UpSampling2D, Conv2D +from tensorflow.keras.models import Sequential, Model +from tensorflow.keras.optimizers import Adam + import matplotlib.pyplot as plt -from IPython.display import clear_output +import matplotlib.gridspec as gridspec + from tqdm import tqdm +import sys + import numpy as np class CDCGAN(): - def __init__(self): + def __init__(self, conv_layers = 1, num_classes = 10): # Input shape + self.num_classes = num_classes self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) - self.num_classes = 10 self.latent_dim = 100 + self.conv_layers = conv_layers - optimizer = Adam(0.0002, 0.5) + optimizer = Adam(0.002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() - - self.discriminator.compile(loss=['binary_crossentropy'], + self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() - # The generator takes noise and the target label as input - # and generates the corresponding digit of that label noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,)) img = self.generator([noise, label]) @@ -47,10 +48,10 @@ class CDCGAN(): # For the combined model we will only train the generator self.discriminator.trainable = False - # The discriminator takes generated image as input and determines validity - # and the label of that image + # The discriminator takes generated images as input and determines validity valid = self.discriminator([img, label]) + # The combined model (stacked generator and discriminator) # Trains generator to fool discriminator self.combined = Model([noise, label], valid) @@ -58,61 +59,78 @@ class CDCGAN(): optimizer=optimizer) def build_generator(self): - # Prepare noise input - input_z = layers.Input((100,)) - dense_z_1 = layers.Dense(1024)(input_z) - act_z_1 = layers.Activation("tanh")(dense_z_1) - dense_z_2 = layers.Dense(128 * 7 * 7)(act_z_1) - bn_z_1 = layers.BatchNormalization()(dense_z_2) - reshape_z = layers.Reshape((7, 7, 128), input_shape=(128 * 7 * 7,))(bn_z_1) - - # Prepare Conditional (label) input - input_c = layers.Input((1,)) - dense_c_1 = layers.Dense(1024)(input_c) - act_c_1 = layers.Activation("tanh")(dense_c_1) - dense_c_2 = layers.Dense(128 * 7 * 7)(act_c_1) - bn_c_1 = layers.BatchNormalization()(dense_c_2) - reshape_c = layers.Reshape((7, 7, 128), input_shape=(128 * 7 * 7,))(bn_c_1) - - # Combine input source - concat_z_c = layers.Concatenate()([reshape_z, reshape_c]) - - # Image generation with the concatenated inputs - up_1 = layers.UpSampling2D(size=(2, 2))(concat_z_c) - conv_1 = layers.Conv2D(64, (5, 5), padding='same')(up_1) - act_1 = layers.Activation("tanh")(conv_1) - up_2 = layers.UpSampling2D(size=(2, 2))(act_1) - conv_2 = layers.Conv2D(1, (5, 5), padding='same')(up_2) - act_2 = layers.Activation("tanh")(conv_2) - model = Model(inputs=[input_z, input_c], outputs=act_2) - return model + model = Sequential() + + model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) + model.add(Reshape((7, 7, 128))) + model.add(UpSampling2D()) + + for i in range(self.conv_layers): + model.add(Conv2D(128, kernel_size=3, padding="same")) + model.add(BatchNormalization()) + model.add(Activation("relu")) + + model.add(UpSampling2D()) + + for i in range(self.conv_layers): + model.add(Conv2D(64, kernel_size=3, padding="same")) + model.add(BatchNormalization()) + + model.add(Activation("relu")) + + model.add(Conv2D(self.channels, kernel_size=3, padding="same")) + model.add(Activation("tanh")) + + #model.summary() + + noise = Input(shape=(self.latent_dim,)) + label = Input(shape=(1,), dtype='int32') + label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label)) + model_input = multiply([noise, label_embedding]) + img = model(model_input) + + return Model([noise, label], img) def build_discriminator(self): - input_gen_image = layers.Input((28, 28, 1)) - conv_1_image = layers.Conv2D(64, (5, 5), padding='same')(input_gen_image) - act_1_image = layers.Activation("tanh")(conv_1_image) - pool_1_image = layers.MaxPooling2D(pool_size=(2, 2))(act_1_image) - conv_2_image = layers.Conv2D(128, (5, 5))(pool_1_image) - act_2_image = layers.Activation("tanh")(conv_2_image) - pool_2_image = layers.MaxPooling2D(pool_size=(2, 2))(act_2_image) - - input_c = layers.Input((1,)) - dense_1_c = layers.Dense(1024)(input_c) - act_1_c = layers.Activation("tanh")(dense_1_c) - dense_2_c = layers.Dense(5 * 5 * 128)(act_1_c) - bn_c = layers.BatchNormalization()(dense_2_c) - reshaped_c = layers.Reshape((5, 5, 128))(bn_c) - - concat = layers.Concatenate()([pool_2_image, reshaped_c]) - - flat = layers.Flatten()(concat) - dense_1 = layers.Dense(1024)(flat) - act_1 = layers.Activation("tanh")(dense_1) - dense_2 = layers.Dense(1)(act_1) - act_2 = layers.Activation('sigmoid')(dense_2) - model = Model(inputs=[input_gen_image, input_c], outputs=act_2) - return model + + model = Sequential() + + model.add(Dense(28 * 28 * 3, activation="relu")) + model.add(Reshape((28, 28, 3))) + model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) + model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) + model.add(ZeroPadding2D(padding=((0,1),(0,1)))) + model.add(BatchNormalization()) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) + model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) + model.add(BatchNormalization()) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) + model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) + model.add(BatchNormalization()) + model.add(LeakyReLU(alpha=0.2)) + model.add(Dropout(0.25)) + model.add(Flatten()) + model.add(Dense(1, activation='sigmoid')) + + #model.summary() + + img = Input(shape=self.img_shape) + + label = Input(shape=(1,), dtype='int32') + + label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label)) + flat_img = Flatten()(img) + + model_input = multiply([flat_img, label_embedding]) + + validity = model(model_input) + + return Model([img, label], validity) def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): @@ -143,6 +161,7 @@ class CDCGAN(): # Sample noise as generator input noise = np.random.normal(0, 1, (batch_size, 100)) + tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Generate a half batch of new images gen_imgs = self.generator.predict([noise, labels]) @@ -224,10 +243,6 @@ class CDCGAN(): return train_data, test_data, val_data, labels_train, labels_test, labels_val -''' if __name__ == '__main__': - cgan = CDCGAN() - cgan.train(epochs=70, batch_size=32, sample_interval=200) - train, test, tr_labels, te_labels = cgan.generate_data() - print(train.shape, test.shape) -''' + cdcgan = CDCGAN() + cdcgan.train(epochs=4000, batch_size=32) |