diff options
Diffstat (limited to 'lib/virtual_batchnorm_impl.py')
-rw-r--r-- | lib/virtual_batchnorm_impl.py | 306 |
1 files changed, 0 insertions, 306 deletions
diff --git a/lib/virtual_batchnorm_impl.py b/lib/virtual_batchnorm_impl.py deleted file mode 100644 index 650eab9..0000000 --- a/lib/virtual_batchnorm_impl.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Virtual batch normalization. - -This technique was first introduced in `Improved Techniques for Training GANs` -(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch -normalization on a minibatch, it fixes a reference subset of the data to use for -calculating normalization statistics. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope - -__all__ = [ - 'VBN', -] - - -def _static_or_dynamic_batch_size(tensor, batch_axis): - """Returns the static or dynamic batch size.""" - batch_size = array_ops.shape(tensor)[batch_axis] - static_batch_size = tensor_util.constant_value(batch_size) - return static_batch_size or batch_size - - -def _statistics(x, axes): - """Calculate the mean and mean square of `x`. - - Modified from the implementation of `tf.nn.moments`. - - Args: - x: A `Tensor`. - axes: Array of ints. Axes along which to compute mean and - variance. - - Returns: - Two `Tensor` objects: `mean` and `square mean`. - """ - # The dynamic range of fp16 is too limited to support the collection of - # sufficient statistics. As a workaround we simply perform the operations - # on 32-bit floats before converting the mean and variance back to fp16 - y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x - - # Compute true mean while keeping the dims for proper broadcasting. - shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - - shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) - mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) - - mean = array_ops.squeeze(mean, axes) - mean_squared = array_ops.squeeze(mean_squared, axes) - if x.dtype == dtypes.float16: - return (math_ops.cast(mean, dtypes.float16), - math_ops.cast(mean_squared, dtypes.float16)) - else: - return (mean, mean_squared) - - -def _validate_init_input_and_get_axis(reference_batch, axis): - """Validate input and return the used axis value.""" - if reference_batch.shape.ndims is None: - raise ValueError('`reference_batch` has unknown dimensions.') - - ndims = reference_batch.shape.ndims - if axis < 0: - used_axis = ndims + axis - else: - used_axis = axis - if used_axis < 0 or used_axis >= ndims: - raise ValueError('Value of `axis` argument ' + str(used_axis) + - ' is out of range for input with rank ' + str(ndims)) - return used_axis - - -def _validate_call_input(tensor_list, batch_dim): - """Verifies that tensor shapes are compatible, except for `batch_dim`.""" - def _get_shape(tensor): - shape = tensor.shape.as_list() - del shape[batch_dim] - return shape - base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0])) - for tensor in tensor_list: - base_shape.assert_is_compatible_with(_get_shape(tensor)) - - -class VBN(object): - """A class to perform virtual batch normalization. - - This technique was first introduced in `Improved Techniques for Training GANs` - (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch - normalization on a minibatch, it fixes a reference subset of the data to use - for calculating normalization statistics. - - To do this, we calculate the reference batch mean and mean square, and modify - those statistics for each example. We use mean square instead of variance, - since it is linear. - - Note that if `center` or `scale` variables are created, they are shared - between all calls to this object. - - The `__init__` API is intended to mimic `tf.layers.batch_normalization` as - closely as possible. - """ - - def __init__(self, - reference_batch, - axis=-1, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer=init_ops.zeros_initializer(), - gamma_initializer=init_ops.ones_initializer(), - beta_regularizer=None, - gamma_regularizer=None, - trainable=True, - name=None, - batch_axis=0): - """Initialize virtual batch normalization object. - - We precompute the 'mean' and 'mean squared' of the reference batch, so that - `__call__` is efficient. This means that the axis must be supplied when the - object is created, not when it is called. - - We precompute 'square mean' instead of 'variance', because the square mean - can be easily adjusted on a per-example basis. - - Args: - reference_batch: A minibatch tensors. This will form the reference data - from which the normalization statistics are calculated. See - https://arxiv.org/abs/1606.03498 for more details. - axis: Integer, the axis that should be normalized (typically the features - axis). For instance, after a `Convolution2D` layer with - `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. If False, - `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can - be disabled since the scaling can be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: String, the name of the ops. - batch_axis: The axis of the batch dimension. This dimension is treated - differently in `virtual batch normalization` vs `batch normalization`. - - Raises: - ValueError: If `reference_batch` has unknown dimensions at graph - construction. - ValueError: If `batch_axis` is the same as `axis`. - """ - axis = _validate_init_input_and_get_axis(reference_batch, axis) - self._epsilon = epsilon - self._beta = 0 - self._gamma = 1 - self._batch_axis = _validate_init_input_and_get_axis( - reference_batch, batch_axis) - - if axis == self._batch_axis: - raise ValueError('`axis` and `batch_axis` cannot be the same.') - - with variable_scope.variable_scope(name, 'VBN', - values=[reference_batch]) as self._vs: - self._reference_batch = reference_batch - - # Calculate important shapes: - # 1) Reduction axes for the reference batch - # 2) Broadcast shape, if necessary - # 3) Reduction axes for the virtual batchnormed batch - # 4) Shape for optional parameters - input_shape = self._reference_batch.shape - ndims = input_shape.ndims - reduction_axes = list(range(ndims)) - del reduction_axes[axis] - - self._broadcast_shape = [1] * len(input_shape) - self._broadcast_shape[axis] = input_shape[axis].value - - self._example_reduction_axes = list(range(ndims)) - del self._example_reduction_axes[max(axis, self._batch_axis)] - del self._example_reduction_axes[min(axis, self._batch_axis)] - - params_shape = self._reference_batch.shape[axis] - - # Determines whether broadcasting is needed. This is slightly different - # than in the `nn.batch_normalization` case, due to `batch_dim`. - self._needs_broadcasting = ( - sorted(self._example_reduction_axes) != list(range(ndims))[:-2]) - - # Calculate the sufficient statistics for the reference batch in a way - # that can be easily modified by additional examples. - self._ref_mean, self._ref_mean_squares = _statistics( - self._reference_batch, reduction_axes) - self._ref_variance = (self._ref_mean_squares - - math_ops.square(self._ref_mean)) - - # Virtual batch normalization uses a weighted average between example - # statistics and the reference batch statistics. - ref_batch_size = _static_or_dynamic_batch_size( - self._reference_batch, self._batch_axis) - self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.) - self._ref_weight = 1. - self._example_weight - - # Make the variables, if necessary. - if center: - self._beta = variable_scope.get_variable( - name='beta', - shape=(params_shape,), - initializer=beta_initializer, - regularizer=beta_regularizer, - trainable=trainable) - if scale: - self._gamma = variable_scope.get_variable( - name='gamma', - shape=(params_shape,), - initializer=gamma_initializer, - regularizer=gamma_regularizer, - trainable=trainable) - - def _virtual_statistics(self, inputs, reduction_axes): - """Compute the statistics needed for virtual batch normalization.""" - cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes) - vb_mean = (self._example_weight * cur_mean + - self._ref_weight * self._ref_mean) - vb_mean_sq = (self._example_weight * cur_mean_sq + - self._ref_weight * self._ref_mean_squares) - return (vb_mean, vb_mean_sq) - - def _broadcast(self, v, broadcast_shape=None): - # The exact broadcast shape depends on the current batch, not the reference - # batch, unless we're calculating the batch normalization of the reference - # batch. - b_shape = broadcast_shape or self._broadcast_shape - if self._needs_broadcasting and v is not None: - return array_ops.reshape(v, b_shape) - return v - - def reference_batch_normalization(self): - """Return the reference batch, but batch normalized.""" - with ops.name_scope(self._vs.name): - return nn.batch_normalization(self._reference_batch, - self._broadcast(self._ref_mean), - self._broadcast(self._ref_variance), - self._broadcast(self._beta), - self._broadcast(self._gamma), - self._epsilon) - - def __call__(self, inputs): - """Run virtual batch normalization on inputs. - - Args: - inputs: Tensor input. - - Returns: - A virtual batch normalized version of `inputs`. - - Raises: - ValueError: If `inputs` shape isn't compatible with the reference batch. - """ - _validate_call_input([inputs, self._reference_batch], self._batch_axis) - - with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]): - # Calculate the statistics on the current input on a per-example basis. - vb_mean, vb_mean_sq = self._virtual_statistics( - inputs, self._example_reduction_axes) - vb_variance = vb_mean_sq - math_ops.square(vb_mean) - - # The exact broadcast shape of the input statistic Tensors depends on the - # current batch, not the reference batch. The parameter broadcast shape - # is independent of the shape of the input statistic Tensor dimensions. - b_shape = self._broadcast_shape[:] # deep copy - b_shape[self._batch_axis] = _static_or_dynamic_batch_size( - inputs, self._batch_axis) - return nn.batch_normalization( - inputs, - self._broadcast(vb_mean, b_shape), - self._broadcast(vb_variance, b_shape), - self._broadcast(self._beta, self._broadcast_shape), - self._broadcast(self._gamma, self._broadcast_shape), - self._epsilon) |