from __future__ import print_function, division import tensorflow.keras as keras import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D from tensorflow.keras.layers import LeakyReLU from tensorflow.keras.layers import UpSampling2D, Conv2D from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam import matplotlib.pyplot as plt from IPython.display import clear_output from tqdm import tqdm from lib.virtual_batch import VirtualBatchNormalization import numpy as np class CGAN(): def __init__(self, dense_layers = 3, virtual_batch_normalization=False): # Input shape self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.num_classes = 10 self.latent_dim = 100 self.dense_layers = dense_layers self.virtual_batch_normalization = virtual_batch_normalization optimizer = Adam(0.0002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss=['binary_crossentropy'], optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() # The generator takes noise and the target label as input # and generates the corresponding digit of that label noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,)) img = self.generator([noise, label]) # For the combined model we will only train the generator self.discriminator.trainable = False # The discriminator takes generated image as input and determines validity # and the label of that image valid = self.discriminator([img, label]) # The combined model (stacked generator and discriminator) # Trains generator to fool discriminator self.combined = Model([noise, label], valid) self.combined.compile(loss=['binary_crossentropy'], optimizer=optimizer) def build_generator(self): model = Sequential() for i in range(self.dense_layers): output_size = 2**(8+i) model.add(Dense(output_size, input_dim=self.latent_dim)) model.add(LeakyReLU(alpha=0.2)) if self.virtual_batch_normalization: model.add(VirtualBatchNormalization(momentum=0.8)) else: model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) #model.summary() noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label)) model_input = multiply([noise, label_embedding]) img = model(model_input) return Model([noise, label], img) def build_discriminator(self): model = Sequential() model.add(Dense(512, input_dim=np.prod(self.img_shape))) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.4)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.4)) model.add(Dense(1, activation='sigmoid')) #model.summary() img = Input(shape=self.img_shape) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label)) flat_img = Flatten()(img) model_input = multiply([flat_img, label_embedding]) validity = model(model_input) return Model([img, label], validity) def train(self, epochs, batch_size=128, sample_interval=50, graph=False): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() # Configure input X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) y_train = y_train.reshape(-1, 1) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) xaxis = np.arange(epochs) loss = np.zeros((2,epochs)) for epoch in tqdm(range(epochs)): # --------------------- # Train Discriminator # --------------------- # Select a random half batch of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs, labels = X_train[idx], y_train[idx] # Sample noise as generator input noise = np.random.normal(0, 1, (batch_size, 100)) tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Generate a half batch of new images gen_imgs = self.generator.predict([noise, labels]) # Train the discriminator d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid) d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Condition on labels sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # Train the generator g_loss = self.combined.train_on_batch([noise, sampled_labels], valid) # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) loss[0][epoch] = d_loss[0] loss[1][epoch] = g_loss # If at save interval => save generated image samples if epoch % sample_interval == 0: self.sample_images(epoch) if graph: plt.plot(xaxis,loss[0]) plt.plot(xaxis,loss[1]) plt.legend(('Discriminator', 'Generator'), loc='best') plt.xlabel('Epoch') plt.ylabel('Binary Crossentropy Loss') def sample_images(self, epoch): r, c = 2, 5 noise = np.random.normal(0, 1, (r * c, 100)) sampled_labels = np.arange(0, 10).reshape(-1, 1) #using dummy_labels would just print zeros to help identify image quality #dummy_labels = np.zeros(32).reshape(-1, 1) gen_imgs = self.generator.predict([noise, sampled_labels]) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray') axs[i,j].set_title("Digit: %d" % sampled_labels[cnt]) axs[i,j].axis('off') cnt += 1 fig.savefig("images/%d.png" % epoch) plt.close() def generate_data(self): noise_train = np.random.normal(0, 1, (55000, 100)) noise_test = np.random.normal(0, 1, (10000, 100)) noise_val = np.random.normal(0, 1, (5000, 100)) labels_train = np.zeros(55000).reshape(-1, 1) labels_test = np.zeros(10000).reshape(-1, 1) labels_val = np.zeros(5000).reshape(-1, 1) for i in range(10): labels_train[i*5500:] = i labels_test[i*1000:] = i labels_val[i*500:] = i train_data = self.generator.predict([noise_train, labels_train]) test_data = self.generator.predict([noise_test, labels_test]) val_data = self.generator.predict([noise_val, labels_val]) labels_train = keras.utils.to_categorical(labels_train, 10) labels_test = keras.utils.to_categorical(labels_test, 10) labels_val = keras.utils.to_categorical(labels_val, 10) return train_data, test_data, val_data, labels_train, labels_test, labels_val ''' if __name__ == '__main__': cgan = CGAN(dense_layers=1, virtual_batch_normalization=True) cgan.train(epochs=7000, batch_size=32, sample_interval=200) train, test, tr_labels, te_labels = cgan.generate_data() print(train.shape, test.shape) '''