aboutsummaryrefslogtreecommitdiff
path: root/lib/virtual_batchnorm_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/virtual_batchnorm_impl.py')
-rw-r--r--lib/virtual_batchnorm_impl.py306
1 files changed, 306 insertions, 0 deletions
diff --git a/lib/virtual_batchnorm_impl.py b/lib/virtual_batchnorm_impl.py
new file mode 100644
index 0000000..650eab9
--- /dev/null
+++ b/lib/virtual_batchnorm_impl.py
@@ -0,0 +1,306 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Virtual batch normalization.
+
+This technique was first introduced in `Improved Techniques for Training GANs`
+(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
+normalization on a minibatch, it fixes a reference subset of the data to use for
+calculating normalization statistics.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+
+__all__ = [
+ 'VBN',
+]
+
+
+def _static_or_dynamic_batch_size(tensor, batch_axis):
+  """Returns the batch size of `tensor`, preferring a static value.
+
+  Args:
+    tensor: A `Tensor` whose batch size is wanted.
+    batch_axis: Integer index of the batch dimension.
+
+  Returns:
+    The value produced by `tensor_util.constant_value` when the size can be
+    resolved at graph-construction time, otherwise the dynamic scalar `Tensor`
+    sliced out of `array_ops.shape(tensor)`.
+  """
+  batch_size = array_ops.shape(tensor)[batch_axis]
+  static_batch_size = tensor_util.constant_value(batch_size)
+  # Compare against `None` explicitly: a statically known batch size of zero is
+  # falsy and would otherwise be mistaken for "unknown".
+  if static_batch_size is None:
+    return batch_size
+  return static_batch_size
+
+
+def _statistics(x, axes):
+  """Computes the mean and the mean of squares of `x` over `axes`.
+
+  Adapted from the implementation of `tf.nn.moments`.
+
+  Args:
+    x: A `Tensor`.
+    axes: Array of ints. Axes to reduce over.
+
+  Returns:
+    A `(mean, mean_squared)` tuple of `Tensor` objects with the reduced axes
+    squeezed out. Both are cast back to float16 when `x` is float16.
+  """
+  is_fp16 = x.dtype == dtypes.float16
+  # float16 lacks the dynamic range needed to accumulate the statistics, so
+  # the reductions run in float32 and the results are cast back at the end.
+  values = math_ops.cast(x, dtypes.float32) if is_fp16 else x
+
+  # Reduce around a gradient-free estimate of the mean, keeping the reduced
+  # dims so the estimate broadcasts against `values`.
+  center = array_ops.stop_gradient(
+      math_ops.reduce_mean(values, axes, keepdims=True))
+  first_moment = (
+      math_ops.reduce_mean(values - center, axes, keepdims=True) + center)
+  second_moment = math_ops.reduce_mean(
+      math_ops.square(values), axes, keepdims=True)
+
+  results = (array_ops.squeeze(first_moment, axes),
+             array_ops.squeeze(second_moment, axes))
+  if is_fp16:
+    results = tuple(math_ops.cast(stat, dtypes.float16) for stat in results)
+  return results
+
+
+def _validate_init_input_and_get_axis(reference_batch, axis):
+  """Validates `reference_batch` and resolves `axis` to a non-negative index.
+
+  Args:
+    reference_batch: Object exposing `shape.ndims` (the rank, or `None` when
+      the rank is unknown).
+    axis: Integer axis. Negative values count from the last dimension.
+
+  Returns:
+    The non-negative axis index equivalent to `axis`.
+
+  Raises:
+    ValueError: If the rank of `reference_batch` is unknown, or if `axis` does
+      not refer to one of its dimensions.
+  """
+  ndims = reference_batch.shape.ndims
+  if ndims is None:
+    raise ValueError('`reference_batch` has unknown dimensions.')
+
+  used_axis = ndims + axis if axis < 0 else axis
+  if not 0 <= used_axis < ndims:
+    # Report the value the caller actually passed; the shifted index would be
+    # misleading for negative inputs (e.g. -5 on rank 3 would print as -2).
+    raise ValueError('Value of `axis` argument ' + str(axis) +
+                     ' is out of range for input with rank ' + str(ndims))
+  return used_axis
+
+
+def _validate_call_input(tensor_list, batch_dim):
+  """Checks that all tensors agree in shape outside of `batch_dim`.
+
+  The first tensor supplies the baseline shape; every tensor in the list
+  (including the first) is then compared against it with the `batch_dim`
+  entry removed.
+
+  Args:
+    tensor_list: List of tensors to compare.
+    batch_dim: Integer index of the dimension excluded from the comparison.
+  """
+  def _shape_without_batch(tensor):
+    dims = tensor.shape.as_list()
+    del dims[batch_dim]
+    return dims
+
+  expected = tensor_shape.TensorShape(_shape_without_batch(tensor_list[0]))
+  for candidate in tensor_list:
+    expected.assert_is_compatible_with(_shape_without_batch(candidate))
+
+
+class VBN(object):
+  """A class to perform virtual batch normalization.
+
+  This technique was first introduced in `Improved Techniques for Training GANs`
+  (Salimans et al., https://arxiv.org/abs/1606.03498). Instead of using batch
+  normalization on a minibatch, it fixes a reference subset of the data to use
+  for calculating normalization statistics.
+
+  To do this, we calculate the reference batch mean and mean square, and modify
+  those statistics for each example. We use mean square instead of variance,
+  since it is linear.
+
+  Note that if `center` or `scale` variables are created, they are shared
+  between all calls to this object.
+
+  The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
+  closely as possible.
+  """
+
+  def __init__(self,
+               reference_batch,
+               axis=-1,
+               epsilon=1e-3,
+               center=True,
+               scale=True,
+               beta_initializer=init_ops.zeros_initializer(),
+               gamma_initializer=init_ops.ones_initializer(),
+               beta_regularizer=None,
+               gamma_regularizer=None,
+               trainable=True,
+               name=None,
+               batch_axis=0):
+    """Initialize virtual batch normalization object.
+
+    We precompute the 'mean' and 'mean squared' of the reference batch, so that
+    `__call__` is efficient. This means that the axis must be supplied when the
+    object is created, not when it is called.
+
+    We precompute 'square mean' instead of 'variance', because the square mean
+    can be easily adjusted on a per-example basis.
+
+    Args:
+      reference_batch: A minibatch tensor. This will form the reference data
+        from which the normalization statistics are calculated. See
+        https://arxiv.org/abs/1606.03498 for more details.
+      axis: Integer, the axis that should be normalized (typically the features
+        axis). For instance, after a `Convolution2D` layer with
+        `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
+      epsilon: Small float added to variance to avoid dividing by zero.
+      center: If True, add offset of `beta` to normalized tensor. If False,
+        `beta` is ignored.
+      scale: If True, multiply by `gamma`. If False, `gamma` is
+        not used. When the next layer is linear (also e.g. `nn.relu`), this can
+        be disabled since the scaling can be done by the next layer.
+      beta_initializer: Initializer for the beta weight.
+      gamma_initializer: Initializer for the gamma weight.
+      beta_regularizer: Optional regularizer for the beta weight.
+      gamma_regularizer: Optional regularizer for the gamma weight.
+      trainable: Boolean, if `True` also add variables to the graph collection
+        `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+      name: String, the name of the ops.
+      batch_axis: The axis of the batch dimension. This dimension is treated
+        differently in `virtual batch normalization` vs `batch normalization`.
+
+    Raises:
+      ValueError: If `reference_batch` has unknown dimensions at graph
+        construction.
+      ValueError: If `axis` or `batch_axis` is out of range for the rank of
+        `reference_batch`.
+      ValueError: If `batch_axis` is the same as `axis`.
+    """
+    axis = _validate_init_input_and_get_axis(reference_batch, axis)
+    self._epsilon = epsilon
+    # `None` means "no offset" / "no scale". `nn.batch_normalization` accepts
+    # `None` for both, and `_broadcast` leaves `None` untouched. Python scalar
+    # placeholders (0 and 1) cannot be reshaped to the per-feature broadcast
+    # shape and made `center=False` / `scale=False` fail whenever broadcasting
+    # was required.
+    self._beta = None
+    self._gamma = None
+    self._batch_axis = _validate_init_input_and_get_axis(
+        reference_batch, batch_axis)
+
+    if axis == self._batch_axis:
+      raise ValueError('`axis` and `batch_axis` cannot be the same.')
+
+    with variable_scope.variable_scope(name, 'VBN',
+                                       values=[reference_batch]) as self._vs:
+      self._reference_batch = reference_batch
+
+      # Calculate important shapes:
+      #  1) Reduction axes for the reference batch
+      #  2) Broadcast shape, if necessary
+      #  3) Reduction axes for the virtual batchnormed batch
+      #  4) Shape for optional parameters
+      input_shape = self._reference_batch.shape
+      ndims = input_shape.ndims
+      reduction_axes = list(range(ndims))
+      del reduction_axes[axis]
+
+      self._broadcast_shape = [1] * ndims
+      self._broadcast_shape[axis] = input_shape[axis].value
+
+      # Per-example statistics reduce over everything except the feature axis
+      # and the batch axis. Delete the larger index first so the smaller one
+      # is still valid.
+      self._example_reduction_axes = list(range(ndims))
+      del self._example_reduction_axes[max(axis, self._batch_axis)]
+      del self._example_reduction_axes[min(axis, self._batch_axis)]
+
+      params_shape = self._reference_batch.shape[axis]
+
+      # Determines whether broadcasting is needed. This is slightly different
+      # than in the `nn.batch_normalization` case, due to `batch_dim`.
+      self._needs_broadcasting = (
+          sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
+
+      # Calculate the sufficient statistics for the reference batch in a way
+      # that can be easily modified by additional examples.
+      self._ref_mean, self._ref_mean_squares = _statistics(
+          self._reference_batch, reduction_axes)
+      self._ref_variance = (self._ref_mean_squares -
+                            math_ops.square(self._ref_mean))
+
+      # Virtual batch normalization uses a weighted average between example
+      # statistics and the reference batch statistics.
+      ref_batch_size = _static_or_dynamic_batch_size(
+          self._reference_batch, self._batch_axis)
+      # Keep the weights in the dtype of the reference batch so they can be
+      # multiplied with its statistics; a fixed float32 weight cannot be
+      # combined with float16 or float64 statistics.
+      ref_batch_size = math_ops.cast(ref_batch_size,
+                                     self._reference_batch.dtype)
+      self._example_weight = 1. / (ref_batch_size + 1.)
+      self._ref_weight = 1. - self._example_weight
+
+      # Make the variables, if necessary.
+      if center:
+        self._beta = variable_scope.get_variable(
+            name='beta',
+            shape=(params_shape,),
+            initializer=beta_initializer,
+            regularizer=beta_regularizer,
+            trainable=trainable)
+      if scale:
+        self._gamma = variable_scope.get_variable(
+            name='gamma',
+            shape=(params_shape,),
+            initializer=gamma_initializer,
+            regularizer=gamma_regularizer,
+            trainable=trainable)
+
+  def _virtual_statistics(self, inputs, reduction_axes):
+    """Compute the statistics needed for virtual batch normalization.
+
+    Args:
+      inputs: Tensor whose statistics are blended with the reference ones.
+      reduction_axes: Axes of `inputs` to reduce over.
+
+    Returns:
+      A `(mean, mean_squared)` tuple, each a weighted average of the statistic
+      of `inputs` and the corresponding reference batch statistic.
+    """
+    cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
+    vb_mean = (self._example_weight * cur_mean +
+               self._ref_weight * self._ref_mean)
+    vb_mean_sq = (self._example_weight * cur_mean_sq +
+                  self._ref_weight * self._ref_mean_squares)
+    return (vb_mean, vb_mean_sq)
+
+  def _broadcast(self, v, broadcast_shape=None):
+    """Reshapes `v` for broadcasting against the input, when required.
+
+    Args:
+      v: Tensor to reshape, or `None`.
+      broadcast_shape: Optional target shape. Defaults to the shape derived
+        from the reference batch.
+
+    Returns:
+      `v` reshaped to the target shape if broadcasting is needed and `v` is
+      not `None`; otherwise `v` unchanged.
+    """
+    # The exact broadcast shape depends on the current batch, not the reference
+    # batch, unless we're calculating the batch normalization of the reference
+    # batch.
+    b_shape = (self._broadcast_shape if broadcast_shape is None
+               else broadcast_shape)
+    if self._needs_broadcasting and v is not None:
+      return array_ops.reshape(v, b_shape)
+    return v
+
+  def reference_batch_normalization(self):
+    """Return the reference batch, but batch normalized."""
+    with ops.name_scope(self._vs.name):
+      return nn.batch_normalization(self._reference_batch,
+                                    self._broadcast(self._ref_mean),
+                                    self._broadcast(self._ref_variance),
+                                    self._broadcast(self._beta),
+                                    self._broadcast(self._gamma),
+                                    self._epsilon)
+
+  def __call__(self, inputs):
+    """Run virtual batch normalization on inputs.
+
+    Args:
+      inputs: Tensor input.
+
+    Returns:
+      A virtual batch normalized version of `inputs`.
+
+    Raises:
+      ValueError: If `inputs` shape isn't compatible with the reference batch.
+    """
+    _validate_call_input([inputs, self._reference_batch], self._batch_axis)
+
+    with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
+      # Calculate the statistics on the current input on a per-example basis.
+      vb_mean, vb_mean_sq = self._virtual_statistics(
+          inputs, self._example_reduction_axes)
+      vb_variance = vb_mean_sq - math_ops.square(vb_mean)
+
+      # The exact broadcast shape of the input statistic Tensors depends on the
+      # current batch, not the reference batch. The parameter broadcast shape
+      # is independent of the shape of the input statistic Tensor dimensions.
+      b_shape = list(self._broadcast_shape)  # copy; the stored list is reused
+      b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
+          inputs, self._batch_axis)
+      return nn.batch_normalization(
+          inputs,
+          self._broadcast(vb_mean, b_shape),
+          self._broadcast(vb_variance, b_shape),
+          self._broadcast(self._beta, self._broadcast_shape),
+          self._broadcast(self._gamma, self._broadcast_shape),
+          self._epsilon)