from __future__ import print_function, division import tensorflow as keras import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply from tensorflow.keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D from tensorflow.keras.layers import LeakyReLU from tensorflow.keras.layers import UpSampling2D, Conv2D from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from tqdm import tqdm import sys import numpy as np class CDCGAN(): def __init__(self, conv_layers = 1, num_classes = 10): # Input shape self.num_classes = num_classes self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers optimizer = Adam(0.002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,)) img = self.generator([noise, label]) # For the combined model we will only train the generator self.discriminator.trainable = False # The discriminator takes generated images as input and determines validity valid = self.discriminator([img, label]) # The combined model (stacked generator and discriminator) # Trains generator to fool discriminator self.combined = Model([noise, label], valid) self.combined.compile(loss=['binary_crossentropy'], optimizer=optimizer) def build_generator(self): model = Sequential() model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) model.add(Reshape((7, 7, 128))) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(128, kernel_size=3, padding="same")) model.add(BatchNormalization()) model.add(Activation("relu")) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(64, kernel_size=3, padding="same")) model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=3, padding="same")) model.add(Activation("tanh")) #model.summary() noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label)) model_input = multiply([noise, label_embedding]) img = model(model_input) return Model([noise, label], img) def build_discriminator(self): model = Sequential() model.add(Dense(28 * 28 * 3, activation="relu")) model.add(Reshape((28, 28, 3))) model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) #model.summary() img = Input(shape=self.img_shape) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label)) flat_img = Flatten()(img) model_input = multiply([flat_img, label_embedding]) validity = model(model_input) return Model([img, label], validity) def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() # Configure input X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) y_train = y_train.reshape(-1, 1) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) xaxis = np.arange(epochs) loss = np.zeros((2,epochs)) for epoch in tqdm(range(epochs)): # --------------------- # Train Discriminator # --------------------- # Select a random half batch of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs, labels = X_train[idx], y_train[idx] # Sample noise as generator input noise = np.random.normal(0, 1, (batch_size, 100)) tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Generate a half batch of new images gen_imgs = self.generator.predict([noise, labels]) # Train the discriminator d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid*smooth_real) d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], valid*smooth_fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Condition on labels sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # Train the generator g_loss = self.combined.train_on_batch([noise, sampled_labels], valid) # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) loss[0][epoch] = d_loss[0] loss[1][epoch] = g_loss # If at save interval => save generated image samples if epoch % sample_interval == 0: self.sample_images(epoch) if graph: plt.plot(xaxis,loss[0]) plt.plot(xaxis,loss[1]) plt.legend(('Discriminator', 'Generator'), loc='best') plt.xlabel('Epoch') plt.ylabel('Binary Crossentropy Loss') def sample_images(self, epoch): r, c = 2, 5 noise = np.random.normal(0, 1, (r * c, 100)) sampled_labels = np.arange(0, 10).reshape(-1, 1) #using dummy_labels would just print zeros to help identify image quality #dummy_labels = np.zeros(32).reshape(-1, 1) gen_imgs = self.generator.predict([noise, sampled_labels]) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].set_title("Digit: %d" % sampled_labels[cnt]) axs[i,j].axis('off') cnt += 1 fig.savefig("images/%d.png" % epoch) plt.close() def generate_data(self): noise_train = np.random.normal(0, 1, (55000, 100)) noise_test = np.random.normal(0, 1, (10000, 100)) noise_val = np.random.normal(0, 1, (5000, 100)) labels_train = np.zeros(55000).reshape(-1, 1) labels_test = np.zeros(10000).reshape(-1, 1) labels_val = np.zeros(5000).reshape(-1, 1) for i in range(10): labels_train[i*5500:] = i labels_test[i*1000:] = i labels_val[i*500:] = i train_data = self.generator.predict([noise_train, labels_train]) test_data = self.generator.predict([noise_test, labels_test]) val_data = self.generator.predict([noise_val, labels_val]) labels_train = keras.utils.to_categorical(labels_train, 10) labels_test = keras.utils.to_categorical(labels_test, 10) labels_val = keras.utils.to_categorical(labels_val, 10) return train_data, test_data, val_data, labels_train, labels_test, labels_val if __name__ == '__main__': cdcgan = CDCGAN() cdcgan.train(epochs=4000, batch_size=32)