# Source: root/models.py (blob 5a28b274cf07dd819f6691c8731b75af6c53f96a)
# models.py
# EE4 Computer Vision coursework: models for the GAN coursework
from keras.models import Model, Sequential
from keras.layers import *

def get_generator():
  """Build the GAN generator network.

  The model takes a 100-dimensional input vector, projects it with a dense
  layer to 128*7*7 units, reshapes that to a (7, 7, 128) feature map and then
  applies two UpSampling2D + 5x5 'same' convolution stages. The last
  convolution has a single filter with a tanh activation.

  Returns:
    An uncompiled Keras ``Sequential`` model.
  """
  # NOTE(review): assuming channels-last ordering and the default 2x
  # UpSampling2D factor, the spatial size goes 7 -> 14 -> 28, giving a
  # (28, 28, 1) output — confirm against the backend's dim ordering.
  return Sequential([
      # Dense projection of the input vector, normalised before reshaping.
      Dense(128*7*7, input_dim=100, activation=LeakyReLU(0.2)),
      BatchNormalization(),
      Reshape((7, 7, 128)),
      # First upsample + convolution stage (64 filters).
      UpSampling2D(),
      Convolution2D(64, 5, 5, border_mode='same', activation=LeakyReLU(0.2)),
      BatchNormalization(),
      # Second upsample + convolution stage; single tanh output channel.
      UpSampling2D(),
      Convolution2D(1, 5, 5, border_mode='same', activation='tanh'),
  ])

def get_discriminator():
  """Build the GAN discriminator network.

  The model accepts inputs of shape (28, 28, 1) and applies two 5x5 'same'
  convolutions with a (2, 2) subsample stride (64 then 128 filters), each
  followed by Dropout(0.3). The result is flattened and fed to a single
  sigmoid unit.

  Returns:
    An uncompiled Keras ``Sequential`` model.
  """
  return Sequential([
      # Strided convolutions downsample instead of using pooling layers.
      Convolution2D(64, 5, 5, subsample=(2, 2), input_shape=(28, 28, 1),
                    border_mode='same', activation=LeakyReLU(0.2)),
      Dropout(0.3),
      Convolution2D(128, 5, 5, subsample=(2, 2), border_mode='same',
                    activation=LeakyReLU(0.2)),
      Dropout(0.3),
      # Collapse the feature maps into one vector for the output unit.
      Flatten(),
      Dense(1, activation='sigmoid'),
  ])