diff options
Diffstat (limited to 'dcgan.py')
-rw-r--r-- | dcgan.py | 18 |
1 files changed, 11 insertions, 7 deletions
@@ -18,7 +18,7 @@ import sys import numpy as np class DCGAN(): - def __init__(self, conv_layers = 1): + def __init__(self, conv_layers = 1, dropout = 0.25): # Input shape self.img_rows = 28 self.img_cols = 28 @@ -26,6 +26,7 @@ class DCGAN(): self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers + self.dropout = dropout optimizer = Adam(0.002, 0.5) @@ -89,20 +90,20 @@ class DCGAN(): model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) + model.add(Dropout(self.dropout)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) + model.add(Dropout(self.dropout)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) + model.add(Dropout(self.dropout)) model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(BatchNormalization()) model.add(LeakyReLU(alpha=0.2)) - model.add(Dropout(0.25)) + model.add(Dropout(self.dropout)) model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) @@ -113,7 +114,7 @@ class DCGAN(): return Model(img, validity) - def train(self, epochs, batch_size=128, save_interval=50, VBN=False): + def train(self, epochs, batch_size=128, save_interval=50, VBN=False, gdstep=1): # Load the dataset (X_train, _), (_, _) = mnist.load_data() @@ -153,7 +154,10 @@ class DCGAN(): # --------------------- # Train the generator (wants discriminator to mistake images as real) - g_loss = self.combined.train_on_batch(noise, valid) + if epoch % gdstep == 0: + g_loss = self.combined.train_on_batch(noise, valid) + else: + g_loss = 0 # Plot the progress #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) |