from __future__ import print_function, division import tensorflow as keras from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D from tensorflow.keras.layers.advanced_activations import LeakyReLU from tensorflow.keras.layers.convolutional import UpSampling2D, Conv2D from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam from lib/virtual_batch import VirtualBatchNormalization import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from tqdm import tqdm import sys import numpy as np class DCGAN(): def __init__(self, conv_layers = 1, virtual_batch_normalization=False): # Input shape self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers self.virtual_batch_normalization = virtual_batch_normalization optimizer = Adam(0.002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() # The generator takes noise as input and generates imgs z = Input(shape=(self.latent_dim,)) img = self.generator(z) # For the combined model we will only train the generator self.discriminator.trainable = False # The discriminator takes generated images as input and determines validity valid = self.discriminator(img) # The combined model (stacked generator and discriminator) # Trains the generator to fool the discriminator self.combined = Model(z, valid) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) def build_generator(self): model = Sequential() model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) model.add(Reshape((7, 7, 128))) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(128, kernel_size=3, padding="same")) if self.virtual_batch_normalization: model.add(VirtualBatchNormalization()) else: model.add(BatchNormalization()) model.add(Activation("relu")) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(64, kernel_size=3, padding="same")) if self.virtual_batch_normalization: model.add(VirtualBatchNormalization()) else: model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=3, padding="same")) model.add(Activation("tanh")) #model.summary() noise = Input(shape=(self.latent_dim,)) img = model(noise) return Model(noise, img) def build_discriminator(self): model = Sequential() model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) #model.summary() img = Input(shape=self.img_shape) validity = model(img) return Model(img, validity) def train(self, epochs, batch_size=128, save_interval=50, VBN=False): # Load the dataset (X_train, _), (_, _) = mnist.load_data() # Rescale -1 to 1 X_train = X_train / 127.5 - 1. X_train = np.expand_dims(X_train, axis=3) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) xaxis = np.arange(epochs) loss = np.zeros((2,epochs)) for epoch in tqdm(range(epochs)): # --------------------- # Train Discriminator # --------------------- # Select a random half of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Sample noise and generate a batch of new images noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Train the discriminator (real classified as ones and generated as zeros) d_loss_real = self.discriminator.train_on_batch(imgs, valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Train the generator (wants discriminator to mistake images as real) g_loss = self.combined.train_on_batch(noise, valid) # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) loss[0][epoch] = d_loss[0] loss[1][epoch] = g_loss # If at save interval => save generated image samples if epoch % save_interval == 0: self.save_imgs(epoch) plt.plot(xaxis,loss[0]) plt.plot(xaxis,loss[1]) plt.legend(('Discriminator', 'Generator'), loc='best') plt.xlabel('Epoch') plt.ylabel('Binary Crossentropy Loss') def save_imgs(self, epoch): r, c = 10, 10 noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) gs = gridspec.GridSpec(r, c) gs.update(wspace=0.05, hspace=0.05) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 fig.savefig("images/mnist_%d.png" % epoch) plt.close() ''' if __name__ == '__main__': dcgan = DCGAN() dcgan.train(epochs=4000, batch_size=32, save_interval=50) '''