From bc501637cdb329db681b439563cdae418f3fa897 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Wed, 6 Mar 2019 20:39:00 +0000 Subject: Revert "Add virtual_batch support" This reverts commit 740e1b0c6a02a7bec20008758373f0dd80baade4. --- cgan.py | 28 +- dcgan.py | 36 +-- lib/__pycache__/virtual_batch.cpython-37.pyc | Bin 1758 -> 0 bytes .../virtual_batchnorm_impl.cpython-37.pyc | Bin 8723 -> 0 bytes lib/virtual_batch.py | 39 --- lib/virtual_batchnorm_impl.py | 306 --------------------- 6 files changed, 22 insertions(+), 387 deletions(-) mode change 100644 => 100755 cgan.py delete mode 100644 lib/__pycache__/virtual_batch.cpython-37.pyc delete mode 100644 lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc delete mode 100644 lib/virtual_batch.py delete mode 100644 lib/virtual_batchnorm_impl.py diff --git a/cgan.py b/cgan.py old mode 100644 new mode 100755 index 45b9bb9..6406244 --- a/cgan.py +++ b/cgan.py @@ -1,23 +1,21 @@ from __future__ import print_function, division import tensorflow.keras as keras import tensorflow as tf -from tensorflow.keras.datasets import mnist -from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply -from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D -from tensorflow.keras.layers import LeakyReLU -from tensorflow.keras.layers import UpSampling2D, Conv2D -from tensorflow.keras.models import Sequential, Model -from tensorflow.keras.optimizers import Adam +from keras.datasets import mnist +from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply +from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D +from keras.layers.advanced_activations import LeakyReLU +from keras.layers.convolutional import UpSampling2D, Conv2D +from keras.models import Sequential, Model +from keras.optimizers import Adam import matplotlib.pyplot as plt from IPython.display import clear_output from tqdm import tqdm -from lib.virtual_batch import VirtualBatchNormalization - import numpy as np class CGAN(): - def __init__(self, dense_layers = 3, virtual_batch_normalization=False): + def __init__(self, dense_layers = 3): # Input shape self.img_rows = 28 self.img_cols = 28 @@ -26,7 +24,6 @@ class CGAN(): self.num_classes = 10 self.latent_dim = 100 self.dense_layers = dense_layers - self.virtual_batch_normalization = virtual_batch_normalization optimizer = Adam(0.0002, 0.5) @@ -66,10 +63,7 @@ class CGAN(): output_size = 2**(8+i) model.add(Dense(output_size, input_dim=self.latent_dim)) model.add(LeakyReLU(alpha=0.2)) - if self.virtual_batch_normalization: - model.add(VirtualBatchNormalization(momentum=0.8)) - else: - model.add(BatchNormalization(momentum=0.8)) + model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) @@ -142,7 +136,6 @@ class CGAN(): # Sample noise as generator input noise = np.random.normal(0, 1, (batch_size, 100)) - tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Generate a half batch of new images gen_imgs = self.generator.predict([noise, labels]) @@ -224,9 +217,10 @@ class CGAN(): return train_data, test_data, val_data, labels_train, labels_test, labels_val + ''' if __name__ == '__main__': - cgan = CGAN(dense_layers=1, virtual_batch_normalization=True) + cgan = CGAN(dense_layers=1) cgan.train(epochs=7000, batch_size=32, sample_interval=200) train, test, tr_labels, te_labels = cgan.generate_data() print(train.shape, test.shape) diff --git a/dcgan.py b/dcgan.py index 21afaac..347f61e 100644 --- a/dcgan.py +++ b/dcgan.py @@ -1,16 +1,11 @@ from __future__ import print_function, division -import tensorflow as keras - -import tensorflow as tf -from tensorflow.keras.datasets import mnist -from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout -from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D -from tensorflow.keras.layers import LeakyReLU -from tensorflow.keras.layers import UpSampling2D, Conv2D -from tensorflow.keras.models import Sequential, Model -from tensorflow.keras.optimizers import Adam - -from lib.virtual_batch import VirtualBatchNormalization +from keras.datasets import mnist +from keras.layers import Input, Dense, Reshape, Flatten, Dropout +from keras.layers import BatchNormalization, Activation, ZeroPadding2D +from keras.layers.advanced_activations import LeakyReLU +from keras.layers.convolutional import UpSampling2D, Conv2D +from keras.models import Sequential, Model +from keras.optimizers import Adam import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec @@ -22,7 +17,7 @@ import sys import numpy as np class DCGAN(): - def __init__(self, conv_layers = 1, virtual_batch_normalization=False): + def __init__(self, conv_layers = 1): # Input shape self.img_rows = 28 self.img_cols = 28 @@ -30,7 +25,6 @@ class DCGAN(): self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100 self.conv_layers = conv_layers - self.virtual_batch_normalization = virtual_batch_normalization optimizer = Adam(0.002, 0.5) @@ -68,21 +62,14 @@ class DCGAN(): for i in range(self.conv_layers): model.add(Conv2D(128, kernel_size=3, padding="same")) - if self.virtual_batch_normalization: - model.add(VirtualBatchNormalization()) - else: - model.add(BatchNormalization()) + model.add(BatchNormalization()) model.add(Activation("relu")) model.add(UpSampling2D()) for i in range(self.conv_layers): model.add(Conv2D(64, kernel_size=3, padding="same")) - if self.virtual_batch_normalization: - model.add(VirtualBatchNormalization()) - else: - model.add(BatchNormalization()) - + model.add(BatchNormalization()) model.add(Activation("relu")) model.add(Conv2D(self.channels, kernel_size=3, padding="same")) @@ -151,7 +138,6 @@ class DCGAN(): idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] - tf.keras.backend.get_session().run(tf.global_variables_initializer()) # Sample noise and generate a batch of new images noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) gen_imgs = self.generator.predict(noise) @@ -203,6 +189,6 @@ class DCGAN(): ''' if __name__ == '__main__': - dcgan = DCGAN(virtual_batch_normalization=True) + dcgan = DCGAN() dcgan.train(epochs=4000, batch_size=32, save_interval=50) ''' diff --git a/lib/__pycache__/virtual_batch.cpython-37.pyc b/lib/__pycache__/virtual_batch.cpython-37.pyc deleted file mode 100644 index 1ca89c1..0000000 Binary files a/lib/__pycache__/virtual_batch.cpython-37.pyc and /dev/null differ diff --git a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc b/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc deleted file mode 100644 index 1d41d7f..0000000 Binary files a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc and /dev/null differ diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py deleted file mode 100644 index dab0419..0000000 --- a/lib/virtual_batch.py +++ /dev/null @@ -1,39 +0,0 @@ -import tensorflow as tf -from tensorflow.keras import backend as K -from tensorflow.keras.layers import Layer -from lib.virtual_batchnorm_impl import VBN -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine.base_layer import InputSpec -from tensorflow.python.keras import initializers - -class VirtualBatchNormalization(Layer): - def __init__(self, - momentum=0.99, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - **kwargs): - - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - - super(VirtualBatchNormalization, self).__init__(**kwargs) - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if not input_shape.ndims: - raise ValueError('Input has undefined rank:', input_shape) - ndims = len(input_shape) - self.input_spec = InputSpec(ndim=ndims) - #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end - - def call(self, x): - outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x) - outputs.set_shape(x.get_shape()) - return outputs - - def compute_output_shape(self, input_shape): - return input_shape diff --git a/lib/virtual_batchnorm_impl.py b/lib/virtual_batchnorm_impl.py deleted file mode 100644 index 650eab9..0000000 --- a/lib/virtual_batchnorm_impl.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Virtual batch normalization. - -This technique was first introduced in `Improved Techniques for Training GANs` -(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch -normalization on a minibatch, it fixes a reference subset of the data to use for -calculating normalization statistics. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope - -__all__ = [ - 'VBN', -] - - -def _static_or_dynamic_batch_size(tensor, batch_axis): - """Returns the static or dynamic batch size.""" - batch_size = array_ops.shape(tensor)[batch_axis] - static_batch_size = tensor_util.constant_value(batch_size) - return static_batch_size or batch_size - - -def _statistics(x, axes): - """Calculate the mean and mean square of `x`. - - Modified from the implementation of `tf.nn.moments`. - - Args: - x: A `Tensor`. - axes: Array of ints. Axes along which to compute mean and - variance. - - Returns: - Two `Tensor` objects: `mean` and `square mean`. - """ - # The dynamic range of fp16 is too limited to support the collection of - # sufficient statistics. As a workaround we simply perform the operations - # on 32-bit floats before converting the mean and variance back to fp16 - y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x - - # Compute true mean while keeping the dims for proper broadcasting. - shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - - shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) - mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) - - mean = array_ops.squeeze(mean, axes) - mean_squared = array_ops.squeeze(mean_squared, axes) - if x.dtype == dtypes.float16: - return (math_ops.cast(mean, dtypes.float16), - math_ops.cast(mean_squared, dtypes.float16)) - else: - return (mean, mean_squared) - - -def _validate_init_input_and_get_axis(reference_batch, axis): - """Validate input and return the used axis value.""" - if reference_batch.shape.ndims is None: - raise ValueError('`reference_batch` has unknown dimensions.') - - ndims = reference_batch.shape.ndims - if axis < 0: - used_axis = ndims + axis - else: - used_axis = axis - if used_axis < 0 or used_axis >= ndims: - raise ValueError('Value of `axis` argument ' + str(used_axis) + - ' is out of range for input with rank ' + str(ndims)) - return used_axis - - -def _validate_call_input(tensor_list, batch_dim): - """Verifies that tensor shapes are compatible, except for `batch_dim`.""" - def _get_shape(tensor): - shape = tensor.shape.as_list() - del shape[batch_dim] - return shape - base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0])) - for tensor in tensor_list: - base_shape.assert_is_compatible_with(_get_shape(tensor)) - - -class VBN(object): - """A class to perform virtual batch normalization. - - This technique was first introduced in `Improved Techniques for Training GANs` - (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch - normalization on a minibatch, it fixes a reference subset of the data to use - for calculating normalization statistics. - - To do this, we calculate the reference batch mean and mean square, and modify - those statistics for each example. We use mean square instead of variance, - since it is linear. - - Note that if `center` or `scale` variables are created, they are shared - between all calls to this object. - - The `__init__` API is intended to mimic `tf.layers.batch_normalization` as - closely as possible. - """ - - def __init__(self, - reference_batch, - axis=-1, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer=init_ops.zeros_initializer(), - gamma_initializer=init_ops.ones_initializer(), - beta_regularizer=None, - gamma_regularizer=None, - trainable=True, - name=None, - batch_axis=0): - """Initialize virtual batch normalization object. - - We precompute the 'mean' and 'mean squared' of the reference batch, so that - `__call__` is efficient. This means that the axis must be supplied when the - object is created, not when it is called. - - We precompute 'square mean' instead of 'variance', because the square mean - can be easily adjusted on a per-example basis. - - Args: - reference_batch: A minibatch tensors. This will form the reference data - from which the normalization statistics are calculated. See - https://arxiv.org/abs/1606.03498 for more details. - axis: Integer, the axis that should be normalized (typically the features - axis). For instance, after a `Convolution2D` layer with - `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. If False, - `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can - be disabled since the scaling can be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: String, the name of the ops. - batch_axis: The axis of the batch dimension. This dimension is treated - differently in `virtual batch normalization` vs `batch normalization`. - - Raises: - ValueError: If `reference_batch` has unknown dimensions at graph - construction. - ValueError: If `batch_axis` is the same as `axis`. - """ - axis = _validate_init_input_and_get_axis(reference_batch, axis) - self._epsilon = epsilon - self._beta = 0 - self._gamma = 1 - self._batch_axis = _validate_init_input_and_get_axis( - reference_batch, batch_axis) - - if axis == self._batch_axis: - raise ValueError('`axis` and `batch_axis` cannot be the same.') - - with variable_scope.variable_scope(name, 'VBN', - values=[reference_batch]) as self._vs: - self._reference_batch = reference_batch - - # Calculate important shapes: - # 1) Reduction axes for the reference batch - # 2) Broadcast shape, if necessary - # 3) Reduction axes for the virtual batchnormed batch - # 4) Shape for optional parameters - input_shape = self._reference_batch.shape - ndims = input_shape.ndims - reduction_axes = list(range(ndims)) - del reduction_axes[axis] - - self._broadcast_shape = [1] * len(input_shape) - self._broadcast_shape[axis] = input_shape[axis].value - - self._example_reduction_axes = list(range(ndims)) - del self._example_reduction_axes[max(axis, self._batch_axis)] - del self._example_reduction_axes[min(axis, self._batch_axis)] - - params_shape = self._reference_batch.shape[axis] - - # Determines whether broadcasting is needed. This is slightly different - # than in the `nn.batch_normalization` case, due to `batch_dim`. - self._needs_broadcasting = ( - sorted(self._example_reduction_axes) != list(range(ndims))[:-2]) - - # Calculate the sufficient statistics for the reference batch in a way - # that can be easily modified by additional examples. - self._ref_mean, self._ref_mean_squares = _statistics( - self._reference_batch, reduction_axes) - self._ref_variance = (self._ref_mean_squares - - math_ops.square(self._ref_mean)) - - # Virtual batch normalization uses a weighted average between example - # statistics and the reference batch statistics. - ref_batch_size = _static_or_dynamic_batch_size( - self._reference_batch, self._batch_axis) - self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.) - self._ref_weight = 1. - self._example_weight - - # Make the variables, if necessary. - if center: - self._beta = variable_scope.get_variable( - name='beta', - shape=(params_shape,), - initializer=beta_initializer, - regularizer=beta_regularizer, - trainable=trainable) - if scale: - self._gamma = variable_scope.get_variable( - name='gamma', - shape=(params_shape,), - initializer=gamma_initializer, - regularizer=gamma_regularizer, - trainable=trainable) - - def _virtual_statistics(self, inputs, reduction_axes): - """Compute the statistics needed for virtual batch normalization.""" - cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes) - vb_mean = (self._example_weight * cur_mean + - self._ref_weight * self._ref_mean) - vb_mean_sq = (self._example_weight * cur_mean_sq + - self._ref_weight * self._ref_mean_squares) - return (vb_mean, vb_mean_sq) - - def _broadcast(self, v, broadcast_shape=None): - # The exact broadcast shape depends on the current batch, not the reference - # batch, unless we're calculating the batch normalization of the reference - # batch. - b_shape = broadcast_shape or self._broadcast_shape - if self._needs_broadcasting and v is not None: - return array_ops.reshape(v, b_shape) - return v - - def reference_batch_normalization(self): - """Return the reference batch, but batch normalized.""" - with ops.name_scope(self._vs.name): - return nn.batch_normalization(self._reference_batch, - self._broadcast(self._ref_mean), - self._broadcast(self._ref_variance), - self._broadcast(self._beta), - self._broadcast(self._gamma), - self._epsilon) - - def __call__(self, inputs): - """Run virtual batch normalization on inputs. - - Args: - inputs: Tensor input. - - Returns: - A virtual batch normalized version of `inputs`. - - Raises: - ValueError: If `inputs` shape isn't compatible with the reference batch. - """ - _validate_call_input([inputs, self._reference_batch], self._batch_axis) - - with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]): - # Calculate the statistics on the current input on a per-example basis. - vb_mean, vb_mean_sq = self._virtual_statistics( - inputs, self._example_reduction_axes) - vb_variance = vb_mean_sq - math_ops.square(vb_mean) - - # The exact broadcast shape of the input statistic Tensors depends on the - # current batch, not the reference batch. The parameter broadcast shape - # is independent of the shape of the input statistic Tensor dimensions. - b_shape = self._broadcast_shape[:] # deep copy - b_shape[self._batch_axis] = _static_or_dynamic_batch_size( - inputs, self._batch_axis) - return nn.batch_normalization( - inputs, - self._broadcast(vb_mean, b_shape), - self._broadcast(vb_variance, b_shape), - self._broadcast(self._beta, self._broadcast_shape), - self._broadcast(self._gamma, self._broadcast_shape), - self._epsilon) -- cgit v1.2.3-54-g00ecf