from __future__ import print_function, division import tensorflow as keras import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D from tensorflow.keras.layers import LeakyReLU from tensorflow.keras.layers import UpSampling2D, Conv2D from tensorflow.keras.models import Sequential, Model from tensorflow.keras.optimizers import Adam from lib.virtual_batch import VirtualBatchNormalization import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from tqdm import tqdm import sys import numpy as np class DCGAN(): def __init__(self, conv_layers = 1, virtual_batch_normalization=False): # Input shape self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers self.virtual_batch_normalization = virtual_batch_normalization optimizer = Adam(0.002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() # The generator takes noise as input and generates imgs z = Input(shape=(self.latent_dim,)) img = self.generator(z) # For the combined model we will only train the generator self.discriminator.trainable = False # The discriminator takes generated images as input and determines validity valid = self.discriminator(img) # The combined model (stacked generator and discriminator) # Trains the generator to fool the discriminator self.combined = Model(z, valid) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) def build_generator(self): model = Sequential() model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) model.add(Reshape((7, 7, 128))) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(128, kernel_size=3, padding="same")) if self.virtual_batch_normalization: model.add(VirtualBatchNormalization()) else: model.add(BatchNormalization()) model.add(Activation("relu")) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(64, kernel_size=3, padding="same")) if self.virtual_batch_normalization: model.add(VirtualBatchNormalization()) else: model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=3, padding="same")) model.add(Activation("tanh")) #model.summary() noise = Input(shape=(self.latent_dim,)) img = model(noise) return Model(noise, img) def build_discriminator(self): model = Sequential() model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) #model.summary() img = Input(shape=self.img_shape) validity = model(img) return Model(img, validity) def train(self, epochs, batch_size=128, save_interval=50, VBN=False): # Load the dataset (X_train, _), (_, _) = mnist.load_data() # Rescale -1 to 1 X_train = X_train / 127.5 - 1. X_train = np.expand_dims(X_train, axis=3) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) xaxis = np.arange(epochs) loss = np.zeros((2,epochs)) for epoch in tqdm(range(epochs)): # --------------------- # Train Discriminator # --------------------- # Select a random half of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Sample noise and generate a batch of new images noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Train the discriminator (real classified as ones and generated as zeros) d_loss_real = self.discriminator.train_on_batch(imgs, valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Train the generator (wants discriminator to mistake images as real) g_loss = self.combined.train_on_batch(noise, valid) # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) loss[0][epoch] = d_loss[0] loss[1][epoch] = g_loss # If at save interval => save generated image samples if epoch % save_interval == 0: self.save_imgs(epoch) plt.plot(xaxis,loss[0]) plt.plot(xaxis,loss[1]) plt.legend(('Discriminator', 'Generator'), loc='best') plt.xlabel('Epoch') plt.ylabel('Binary Crossentropy Loss') def save_imgs(self, epoch): r, c = 10, 10 noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) gs = gridspec.GridSpec(r, c) gs.update(wspace=0.05, hspace=0.05) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 fig.savefig("images/mnist_%d.png" % epoch) plt.close() ''' if __name__ == '__main__': dcgan = DCGAN(virtual_batch_normalization=True) dcgan.train(epochs=4000, batch_size=32, save_interval=50) '''