-rwxr-xr-x (was -rw-r--r--)   cgan.py                                                  28
-rw-r--r--                    dcgan.py                                                 36
-rw-r--r--                    lib/__pycache__/virtual_batch.cpython-37.pyc             bin 1758 -> 0 bytes
-rw-r--r--                    lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc    bin 8723 -> 0 bytes
-rw-r--r--                    lib/virtual_batch.py                                     39
-rw-r--r--                    lib/virtual_batchnorm_impl.py                            306
6 files changed, 22 insertions, 387 deletions
diff --git a/cgan.py b/cgan.py
old mode 100644
new mode 100755
--- a/cgan.py
+++ b/cgan.py
@@ -1,23 +1,21 @@
 from __future__ import print_function, division
 import tensorflow.keras as keras
 import tensorflow as tf
-from tensorflow.keras.datasets import mnist
-from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
-from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
-from tensorflow.keras.layers import LeakyReLU
-from tensorflow.keras.layers import UpSampling2D, Conv2D
-from tensorflow.keras.models import Sequential, Model
-from tensorflow.keras.optimizers import Adam
+from keras.datasets import mnist
+from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
+from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling2D, Conv2D
+from keras.models import Sequential, Model
+from keras.optimizers import Adam
 import matplotlib.pyplot as plt
 from IPython.display import clear_output
 from tqdm import tqdm
-from lib.virtual_batch import VirtualBatchNormalization
-
 import numpy as np

 class CGAN():
-    def __init__(self, dense_layers = 3, virtual_batch_normalization=False):
+    def __init__(self, dense_layers = 3):
         # Input shape
         self.img_rows = 28
         self.img_cols = 28
@@ -26,7 +24,6 @@ class CGAN():
         self.num_classes = 10
         self.latent_dim = 100
         self.dense_layers = dense_layers
-        self.virtual_batch_normalization = virtual_batch_normalization

         optimizer = Adam(0.0002, 0.5)

@@ -66,10 +63,7 @@ class CGAN():
             output_size = 2**(8+i)
             model.add(Dense(output_size, input_dim=self.latent_dim))
             model.add(LeakyReLU(alpha=0.2))
-            if self.virtual_batch_normalization:
-                model.add(VirtualBatchNormalization(momentum=0.8))
-            else:
-                model.add(BatchNormalization(momentum=0.8))
+            model.add(BatchNormalization(momentum=0.8))

         model.add(Dense(np.prod(self.img_shape), activation='tanh'))
         model.add(Reshape(self.img_shape))
@@ -142,7 +136,6 @@ class CGAN():
             # Sample noise as generator input
             noise = np.random.normal(0, 1, (batch_size, 100))
-            tf.keras.backend.get_session().run(tf.global_variables_initializer())

             # Generate a half batch of new images
             gen_imgs = self.generator.predict([noise, labels])
@@ -224,9 +217,10 @@ class CGAN():
       return train_data, test_data, val_data, labels_train, labels_test, labels_val

+  '''
 if __name__ == '__main__':
-    cgan = CGAN(dense_layers=1, virtual_batch_normalization=True)
+    cgan = CGAN(dense_layers=1)
     cgan.train(epochs=7000, batch_size=32, sample_interval=200)
     train, test, tr_labels, te_labels = cgan.generate_data()
     print(train.shape, test.shape)
diff --git a/dcgan.py b/dcgan.py
--- a/dcgan.py
+++ b/dcgan.py
@@ -1,16 +1,11 @@
 from __future__ import print_function, division
-import tensorflow as keras
-
-import tensorflow as tf
-from tensorflow.keras.datasets import mnist
-from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from tensorflow.keras.layers import LeakyReLU
-from tensorflow.keras.layers import UpSampling2D, Conv2D
-from tensorflow.keras.models import Sequential, Model
-from tensorflow.keras.optimizers import Adam
-
-from lib.virtual_batch import VirtualBatchNormalization
+from keras.datasets import mnist
+from keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling2D, Conv2D
+from keras.models import Sequential, Model
+from keras.optimizers import Adam

 import matplotlib.pyplot as plt
 import matplotlib.gridspec as gridspec
@@ -22,7 +17,7 @@ import sys
 import numpy as np

 class DCGAN():
-    def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
+    def __init__(self, conv_layers = 1):
         # Input shape
         self.img_rows = 28
         self.img_cols = 28
@@ -30,7 +25,6 @@ class DCGAN():
         self.img_shape = (self.img_rows, self.img_cols, self.channels)
         self.latent_dim = 100
         self.conv_layers = conv_layers
-        self.virtual_batch_normalization = virtual_batch_normalization

         optimizer = Adam(0.002, 0.5)

@@ -68,21 +62,14 @@ class DCGAN():
         for i in range(self.conv_layers):
             model.add(Conv2D(128, kernel_size=3, padding="same"))
-            if self.virtual_batch_normalization:
-                model.add(VirtualBatchNormalization())
-            else:
-                model.add(BatchNormalization())
+            model.add(BatchNormalization())
             model.add(Activation("relu"))

         model.add(UpSampling2D())
         for i in range(self.conv_layers):
             model.add(Conv2D(64, kernel_size=3, padding="same"))
-            if self.virtual_batch_normalization:
-                model.add(VirtualBatchNormalization())
-            else:
-                model.add(BatchNormalization())
-
+            model.add(BatchNormalization())
             model.add(Activation("relu"))

         model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -151,7 +138,6 @@ class DCGAN():
             idx = np.random.randint(0, X_train.shape[0], batch_size)
             imgs = X_train[idx]
-            tf.keras.backend.get_session().run(tf.global_variables_initializer())

             # Sample noise and generate a batch of new images
             noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
             gen_imgs = self.generator.predict(noise)
@@ -203,6 +189,6 @@ class DCGAN():
 '''
 if __name__ == '__main__':
-    dcgan = DCGAN(virtual_batch_normalization=True)
+    dcgan = DCGAN()
     dcgan.train(epochs=4000, batch_size=32, save_interval=50)
 '''
diff --git a/lib/__pycache__/virtual_batch.cpython-37.pyc b/lib/__pycache__/virtual_batch.cpython-37.pyc
deleted file mode 100644
index 1ca89c1..0000000
--- a/lib/__pycache__/virtual_batch.cpython-37.pyc
+++ /dev/null
Binary files differ
diff --git a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc b/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc
deleted file mode 100644
index 1d41d7f..0000000
--- a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc
+++ /dev/null
Binary files differ
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
deleted file mode 100644
index dab0419..0000000
--- a/lib/virtual_batch.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import Layer
-from lib.virtual_batchnorm_impl import VBN
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine.base_layer import InputSpec
-from tensorflow.python.keras import initializers
-
-class VirtualBatchNormalization(Layer):
-    def __init__(self,
-               momentum=0.99,
-               center=True,
-               scale=True,
-               beta_initializer='zeros',
-               gamma_initializer='ones',
-               beta_regularizer=None,
-               gamma_regularizer=None,
-               **kwargs):
-
-        self.beta_initializer = initializers.get(beta_initializer)
-        self.gamma_initializer = initializers.get(gamma_initializer)
-
-        super(VirtualBatchNormalization, self).__init__(**kwargs)
-
-    def build(self, input_shape):
-        input_shape = tensor_shape.TensorShape(input_shape)
-        if not input_shape.ndims:
-          raise ValueError('Input has undefined rank:', input_shape)
-        ndims = len(input_shape)
-        self.input_spec = InputSpec(ndim=ndims)
-        #super(VirtualBatchNormalization, self).build(input_shape)  # Be sure to call this at the end
-
-    def call(self, x):
-        outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
-        outputs.set_shape(x.get_shape())
-        return outputs
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
diff --git a/lib/virtual_batchnorm_impl.py b/lib/virtual_batchnorm_impl.py
deleted file mode 100644
index 650eab9..0000000
--- a/lib/virtual_batchnorm_impl.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Virtual batch normalization.
-
-This technique was first introduced in `Improved Techniques for Training GANs`
-(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
-normalization on a minibatch, it fixes a reference subset of the data to use for
-calculating normalization statistics.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import variable_scope
-
-__all__ = [
-    'VBN',
-]
-
-
-def _static_or_dynamic_batch_size(tensor, batch_axis):
-  """Returns the static or dynamic batch size."""
-  batch_size = array_ops.shape(tensor)[batch_axis]
-  static_batch_size = tensor_util.constant_value(batch_size)
-  return static_batch_size or batch_size
-
-
-def _statistics(x, axes):
-  """Calculate the mean and mean square of `x`.
-
-  Modified from the implementation of `tf.nn.moments`.
-
-  Args:
-    x: A `Tensor`.
-    axes: Array of ints.  Axes along which to compute mean and
-      variance.
-
-  Returns:
-    Two `Tensor` objects: `mean` and `square mean`.
-  """
-  # The dynamic range of fp16 is too limited to support the collection of
-  # sufficient statistics. As a workaround we simply perform the operations
-  # on 32-bit floats before converting the mean and variance back to fp16
-  y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
-
-  # Compute true mean while keeping the dims for proper broadcasting.
-  shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
-
-  shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True)
-  mean = shifted_mean + shift
-  mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)
-
-  mean = array_ops.squeeze(mean, axes)
-  mean_squared = array_ops.squeeze(mean_squared, axes)
-  if x.dtype == dtypes.float16:
-    return (math_ops.cast(mean, dtypes.float16),
-            math_ops.cast(mean_squared, dtypes.float16))
-  else:
-    return (mean, mean_squared)
-
-
-def _validate_init_input_and_get_axis(reference_batch, axis):
-  """Validate input and return the used axis value."""
-  if reference_batch.shape.ndims is None:
-    raise ValueError('`reference_batch` has unknown dimensions.')
-
-  ndims = reference_batch.shape.ndims
-  if axis < 0:
-    used_axis = ndims + axis
-  else:
-    used_axis = axis
-  if used_axis < 0 or used_axis >= ndims:
-    raise ValueError('Value of `axis` argument ' + str(used_axis) +
-                     ' is out of range for input with rank ' + str(ndims))
-  return used_axis
-
-
-def _validate_call_input(tensor_list, batch_dim):
-  """Verifies that tensor shapes are compatible, except for `batch_dim`."""
-  def _get_shape(tensor):
-    shape = tensor.shape.as_list()
-    del shape[batch_dim]
-    return shape
-  base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
-  for tensor in tensor_list:
-    base_shape.assert_is_compatible_with(_get_shape(tensor))
-
-
-class VBN(object):
-  """A class to perform virtual batch normalization.
-
-  This technique was first introduced in `Improved Techniques for Training GANs`
-  (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
-  normalization on a minibatch, it fixes a reference subset of the data to use
-  for calculating normalization statistics.
-
-  To do this, we calculate the reference batch mean and mean square, and modify
-  those statistics for each example. We use mean square instead of variance,
-  since it is linear.
-
-  Note that if `center` or `scale` variables are created, they are shared
-  between all calls to this object.
-
-  The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
-  closely as possible.
-  """
-
-  def __init__(self,
-               reference_batch,
-               axis=-1,
-               epsilon=1e-3,
-               center=True,
-               scale=True,
-               beta_initializer=init_ops.zeros_initializer(),
-               gamma_initializer=init_ops.ones_initializer(),
-               beta_regularizer=None,
-               gamma_regularizer=None,
-               trainable=True,
-               name=None,
-               batch_axis=0):
-    """Initialize virtual batch normalization object.
-
-    We precompute the 'mean' and 'mean squared' of the reference batch, so that
-    `__call__` is efficient. This means that the axis must be supplied when the
-    object is created, not when it is called.
-
-    We precompute 'square mean' instead of 'variance', because the square mean
-    can be easily adjusted on a per-example basis.
-
-    Args:
-      reference_batch: A minibatch tensors. This will form the reference data
-        from which the normalization statistics are calculated. See
-        https://arxiv.org/abs/1606.03498 for more details.
-      axis: Integer, the axis that should be normalized (typically the features
-        axis). For instance, after a `Convolution2D` layer with
-        `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
-      epsilon: Small float added to variance to avoid dividing by zero.
-      center: If True, add offset of `beta` to normalized tensor. If False,
-        `beta` is ignored.
-      scale: If True, multiply by `gamma`. If False, `gamma` is
-        not used. When the next layer is linear (also e.g. `nn.relu`), this can
-        be disabled since the scaling can be done by the next layer.
-      beta_initializer: Initializer for the beta weight.
-      gamma_initializer: Initializer for the gamma weight.
-      beta_regularizer: Optional regularizer for the beta weight.
-      gamma_regularizer: Optional regularizer for the gamma weight.
-      trainable: Boolean, if `True` also add variables to the graph collection
-        `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
-      name: String, the name of the ops.
-      batch_axis: The axis of the batch dimension. This dimension is treated
-        differently in `virtual batch normalization` vs `batch normalization`.
-
-    Raises:
-      ValueError: If `reference_batch` has unknown dimensions at graph
-        construction.
-      ValueError: If `batch_axis` is the same as `axis`.
-    """
-    axis = _validate_init_input_and_get_axis(reference_batch, axis)
-    self._epsilon = epsilon
-    self._beta = 0
-    self._gamma = 1
-    self._batch_axis = _validate_init_input_and_get_axis(
-        reference_batch, batch_axis)
-
-    if axis == self._batch_axis:
-      raise ValueError('`axis` and `batch_axis` cannot be the same.')
-
-    with variable_scope.variable_scope(name, 'VBN',
-                                       values=[reference_batch]) as self._vs:
-      self._reference_batch = reference_batch
-
-      # Calculate important shapes:
-      #  1) Reduction axes for the reference batch
-      #  2) Broadcast shape, if necessary
-      #  3) Reduction axes for the virtual batchnormed batch
-      #  4) Shape for optional parameters
-      input_shape = self._reference_batch.shape
-      ndims = input_shape.ndims
-      reduction_axes = list(range(ndims))
-      del reduction_axes[axis]
-
-      self._broadcast_shape = [1] * len(input_shape)
-      self._broadcast_shape[axis] = input_shape[axis].value
-
-      self._example_reduction_axes = list(range(ndims))
-      del self._example_reduction_axes[max(axis, self._batch_axis)]
-      del self._example_reduction_axes[min(axis, self._batch_axis)]
-
-      params_shape = self._reference_batch.shape[axis]
-
-      # Determines whether broadcasting is needed. This is slightly different
-      # than in the `nn.batch_normalization` case, due to `batch_dim`.
-      self._needs_broadcasting = (
-          sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
-
-      # Calculate the sufficient statistics for the reference batch in a way
-      # that can be easily modified by additional examples.
-      self._ref_mean, self._ref_mean_squares = _statistics(
-          self._reference_batch, reduction_axes)
-      self._ref_variance = (self._ref_mean_squares -
-                            math_ops.square(self._ref_mean))
-
-      # Virtual batch normalization uses a weighted average between example
-      # statistics and the reference batch statistics.
-      ref_batch_size = _static_or_dynamic_batch_size(
-          self._reference_batch, self._batch_axis)
-      self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
-      self._ref_weight = 1. - self._example_weight
-
-      # Make the variables, if necessary.
-      if center:
-        self._beta = variable_scope.get_variable(
-            name='beta',
-            shape=(params_shape,),
-            initializer=beta_initializer,
-            regularizer=beta_regularizer,
-            trainable=trainable)
-      if scale:
-        self._gamma = variable_scope.get_variable(
-            name='gamma',
-            shape=(params_shape,),
-            initializer=gamma_initializer,
-            regularizer=gamma_regularizer,
-            trainable=trainable)
-
-  def _virtual_statistics(self, inputs, reduction_axes):
-    """Compute the statistics needed for virtual batch normalization."""
-    cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
-    vb_mean = (self._example_weight * cur_mean +
-               self._ref_weight * self._ref_mean)
-    vb_mean_sq = (self._example_weight * cur_mean_sq +
-                  self._ref_weight * self._ref_mean_squares)
-    return (vb_mean, vb_mean_sq)
-
-  def _broadcast(self, v, broadcast_shape=None):
-    # The exact broadcast shape depends on the current batch, not the reference
-    # batch, unless we're calculating the batch normalization of the reference
-    # batch.
-    b_shape = broadcast_shape or self._broadcast_shape
-    if self._needs_broadcasting and v is not None:
-      return array_ops.reshape(v, b_shape)
-    return v
-
-  def reference_batch_normalization(self):
-    """Return the reference batch, but batch normalized."""
-    with ops.name_scope(self._vs.name):
-      return nn.batch_normalization(self._reference_batch,
-                                    self._broadcast(self._ref_mean),
-                                    self._broadcast(self._ref_variance),
-                                    self._broadcast(self._beta),
-                                    self._broadcast(self._gamma),
-                                    self._epsilon)
-
-  def __call__(self, inputs):
-    """Run virtual batch normalization on inputs.
-
-    Args:
-      inputs: Tensor input.
-
-    Returns:
-       A virtual batch normalized version of `inputs`.
-
-    Raises:
-       ValueError: If `inputs` shape isn't compatible with the reference batch.
-    """
-    _validate_call_input([inputs, self._reference_batch], self._batch_axis)
-
-    with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
-      # Calculate the statistics on the current input on a per-example basis.
-      vb_mean, vb_mean_sq = self._virtual_statistics(
-          inputs, self._example_reduction_axes)
-      vb_variance = vb_mean_sq - math_ops.square(vb_mean)
-
-      # The exact broadcast shape of the input statistic Tensors depends on the
-      # current batch, not the reference batch. The parameter broadcast shape
-      # is independent of the shape of the input statistic Tensor dimensions.
-      b_shape = self._broadcast_shape[:]  # deep copy
-      b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
-          inputs, self._batch_axis)
-      return nn.batch_normalization(
-          inputs,
-          self._broadcast(vb_mean, b_shape),
-          self._broadcast(vb_variance, b_shape),
-          self._broadcast(self._beta, self._broadcast_shape),
-          self._broadcast(self._gamma, self._broadcast_shape),
-          self._epsilon)
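For readers skimming this change: the VBN class deleted above reduces to a small amount of arithmetic. A reference batch is fixed up front and contributes precomputed mean and mean-square statistics; each incoming example is then normalized against a weighted average of its own statistics and the reference ones, with weight 1/(N_ref + 1) on the example, as described in the deleted docstrings and in Salimans et al. (2016). Below is a minimal NumPy sketch of that computation for a 2-D (batch, features) input; the function name, shapes, and scalar gamma/beta are illustrative assumptions, not the deleted TensorFlow implementation.

import numpy as np

def vbn_normalize(x, ref_batch, eps=1e-3, gamma=1.0, beta=0.0):
    """Illustrative virtual batch norm for 2-D (batch, features) arrays."""
    n_ref = ref_batch.shape[0]
    w_example = 1.0 / (n_ref + 1.0)  # weight on the incoming example
    w_ref = 1.0 - w_example          # weight on the fixed reference batch

    # Reference statistics can be computed once and reused; the mean square
    # (rather than the variance) is kept because it averages linearly.
    ref_mean = ref_batch.mean(axis=0)
    ref_mean_sq = (ref_batch ** 2).mean(axis=0)

    # For a (batch, features) input there are no extra axes to reduce over,
    # so each example's own "statistics" are just the example itself.
    vb_mean = w_example * x + w_ref * ref_mean
    vb_mean_sq = w_example * x ** 2 + w_ref * ref_mean_sq
    vb_var = vb_mean_sq - vb_mean ** 2

    return gamma * (x - vb_mean) / np.sqrt(vb_var + eps) + beta

rng = np.random.default_rng(0)
ref = rng.normal(size=(64, 100))    # fixed reference batch
x = rng.normal(size=(32, 100))      # current minibatch
print(vbn_normalize(x, ref).shape)  # (32, 100)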
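The deleted implementation also explains that it stores mean squares rather than variances "since it is linear": means and mean squares combine as a weighted average across subsets of data, while a weighted average of subset variances drops the between-subset spread. A short self-contained NumPy check of that claim, with illustrative data:

import numpy as np

rng = np.random.default_rng(1)
a = rng.normal(loc=0.0, size=100)  # subset 1
b = rng.normal(loc=3.0, size=50)   # subset 2, different mean
w = len(a) / (len(a) + len(b))
combined = np.concatenate([a, b])

# Means and mean squares combine linearly across subsets...
mean = w * a.mean() + (1 - w) * b.mean()
mean_sq = w * np.mean(a ** 2) + (1 - w) * np.mean(b ** 2)
assert np.isclose(mean, combined.mean())
assert np.isclose(mean_sq, np.mean(combined ** 2))

# ...but variances do not: the weighted average of subset variances
# misses the between-subset term. The correct combined variance is
# recovered from the linear statistics as E[x^2] - E[x]^2.
naive_var = w * a.var() + (1 - w) * b.var()
print(naive_var, combined.var(), mean_sq - mean ** 2)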
