from __future__ import print_function, division import tensorflow as keras import tensorflow as tf import tensorflow.keras as keras from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply from keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D from keras.layers import LeakyReLU from keras.layers import UpSampling2D, Conv2D, Conv2DTranspose from keras.models import Sequential, Model from keras.optimizers import Adam import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from tqdm import tqdm import sys import numpy as np class CDCGAN(): def __init__(self, conv_layers = 1, num_classes = 10): # Input shape self.num_classes = num_classes self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers optimizer = Adam(0.002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,)) img = self.generator([noise, label]) # For the combined model we will only train the generator self.discriminator.trainable = False # The discriminator takes generated images as input and determines validity valid = self.discriminator([img, label]) # The combined model (stacked generator and discriminator) # Trains generator to fool discriminator self.combined = Model([noise, label], valid) self.combined.compile(loss=['binary_crossentropy'], optimizer=optimizer) def build_generator(self): model = Sequential() model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) model.add(Reshape((7, 7, 128))) model.add(Conv2DTranspose(256, kernel_size=3, padding="same", strides=(2,2))) model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2DTranspose(128, kernel_size=3, padding="same", strides=(2,2))) model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2DTranspose(64, kernel_size=3, padding="same")) model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2DTranspose(1, kernel_size=3, padding="same")) model.add(Activation("tanh")) noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label)) model_input = multiply([noise, label_embedding]) img = model(model_input) #model.summary() return Model([noise, label], img) def build_discriminator(self): model = Sequential() model.add(Dense(28 * 28 * 3, activation="relu")) model.add(Reshape((28, 28, 3))) model.add(Conv2D(32, kernel_size=3, strides=2, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) #model.summary() img = Input(shape=self.img_shape) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label)) flat_img = Flatten()(img) model_input = multiply([flat_img, label_embedding]) validity = model(model_input) return Model([img, label], validity) def train(self, epochs, batch_size=128, sample_interval=50, graph=False, smooth_real=1, smooth_fake=0): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() # Configure input X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) y_train = y_train.reshape(-1, 1) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) xaxis = np.arange(epochs) loss = np.zeros((2,epochs)) for epoch in tqdm(range(epochs)): # --------------------- # Train Discriminator # --------------------- # Select a random half batch of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs, labels = X_train[idx], y_train[idx] # Sample noise as generator input noise = np.random.normal(0, 1, (batch_size, 100)) # Generate a half batch of new images gen_imgs = self.generator.predict([noise, labels]) # Train the discriminator d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid*smooth_real) d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], valid*smooth_fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Condition on labels sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # Train the generator g_loss = self.combined.train_on_batch([noise, sampled_labels], valid) # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) loss[0][epoch] = d_loss[0] loss[1][epoch] = g_loss # If at save interval => save generated image samples if epoch % sample_interval == 0: self.sample_images(epoch) if graph: plt.plot(xaxis,loss[0]) plt.plot(xaxis,loss[1]) plt.legend(('Discriminator', 'Generator'), loc='best') plt.xlabel('Epoch') plt.ylabel('Binary Crossentropy Loss') def sample_images(self, epoch): r, c = 2, 5 noise = np.random.normal(0, 1, (r * c, 100)) sampled_labels = np.arange(0, 10).reshape(-1, 1) #using dummy_labels would just print zeros to help identify image quality #dummy_labels = np.zeros(32).reshape(-1, 1) gen_imgs = self.generator.predict([noise, sampled_labels]) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].set_title("Digit: %d" % sampled_labels[cnt]) axs[i,j].axis('off') cnt += 1 fig.savefig("images/%d.png" % epoch) plt.close() def generate_data(self): noise_train = np.random.normal(0, 1, (55000, 100)) noise_test = np.random.normal(0, 1, (10000, 100)) noise_val = np.random.normal(0, 1, (5000, 100)) labels_train = np.zeros(55000).reshape(-1, 1) labels_test = np.zeros(10000).reshape(-1, 1) labels_val = np.zeros(5000).reshape(-1, 1) for i in range(10): labels_train[i*5500:-1] = i labels_test[i*1000:-1] = i labels_val[i*500:-1] = i train_data = self.generator.predict([noise_train, labels_train]) test_data = self.generator.predict([noise_test, labels_test]) val_data = self.generator.predict([noise_val, labels_val]) labels_train = keras.utils.to_categorical(labels_train, 10) labels_test = keras.utils.to_categorical(labels_test, 10) labels_val = keras.utils.to_categorical(labels_val, 10) return train_data, test_data, val_data, labels_train, labels_test, labels_val ''' if __name__ == '__main__': cdcgan = CDCGAN() cdcgan.train(epochs=4000, batch_size=32) '''