author     Vasil Zlatanov <v@skozl.com>    2019-03-05 14:29:29 +0000
committer  Vasil Zlatanov <v@skozl.com>    2019-03-05 14:30:03 +0000
commit     740e1b0c6a02a7bec20008758373f0dd80baade4 (patch)
tree       eb9795d8013c8eb44a01176979ede6bef600ec3f
parent     802f52a2410ed20cea55e8c097b3875111a80824 (diff)
Add virtual_batch support
-rw-r--r--  (was -rwxr-xr-x)   cgan.py                                                  28
-rw-r--r--                     dcgan.py                                                 34
-rw-r--r--                     lib/__pycache__/virtual_batch.cpython-37.pyc             bin 0 -> 1758 bytes
-rw-r--r--                     lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc    bin 0 -> 8723 bytes
-rw-r--r--                     lib/virtual_batch.py                                     39
-rw-r--r--                     lib/virtual_batchnorm_impl.py                            306
6 files changed, 386 insertions(+), 21 deletions(-)
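As a rough usage sketch of what this commit enables (not part of the commit itself; it simply mirrors the __main__ block in cgan.py below and assumes the same train() arguments):

    from cgan import CGAN

    # Hypothetical driver: swap the generator's BatchNormalization layers for
    # the new VirtualBatchNormalization by passing the flag added in this commit.
    cgan = CGAN(dense_layers=1, virtual_batch_normalization=True)
    cgan.train(epochs=7000, batch_size=32, sample_interval=200)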
diff --git a/cgan.py b/cgan.py
index 5ab0c10..b9928f0 100755..100644
--- a/cgan.py
+++ b/cgan.py
@@ -1,21 +1,23 @@
from __future__ import print_function, division
import tensorflow.keras as keras
import tensorflow as tf
-from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
-from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
-from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
-from keras.optimizers import Adam
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
+from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
+from tensorflow.keras.layers import LeakyReLU
+from tensorflow.keras.layers import UpSampling2D, Conv2D
+from tensorflow.keras.models import Sequential, Model
+from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
+from lib.virtual_batch import VirtualBatchNormalization
+
import numpy as np
class CGAN():
- def __init__(self, dense_layers = 3):
+ def __init__(self, dense_layers = 3, virtual_batch_normalization=False):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -24,6 +26,7 @@ class CGAN():
self.num_classes = 10
self.latent_dim = 100
self.dense_layers = dense_layers
+ self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.0002, 0.5)
@@ -63,7 +66,10 @@ class CGAN():
output_size = 2**(8+i)
model.add(Dense(output_size, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization(momentum=0.8))
+ else:
+ model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
@@ -136,6 +142,7 @@ class CGAN():
# Sample noise as generator input
noise = np.random.normal(0, 1, (batch_size, 100))
+ tf.keras.backend.get_session().run(tf.global_variables_initializer())
# Generate a half batch of new images
gen_imgs = self.generator.predict([noise, labels])
@@ -217,10 +224,9 @@ class CGAN():
return train_data, test_data, val_data, labels_train, labels_test, labels_val
-
'''
if __name__ == '__main__':
- cgan = CGAN(dense_layers=1)
+ cgan = CGAN(dense_layers=1, virtual_batch_normalization=True)
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)
diff --git a/dcgan.py b/dcgan.py
index bc7e14e..eca1852 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -1,11 +1,15 @@
from __future__ import print_function, division
-from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
-from keras.models import Sequential, Model
-from keras.optimizers import Adam
+import tensorflow as tf
+
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from tensorflow.keras.layers import LeakyReLU
+from tensorflow.keras.layers import UpSampling2D, Conv2D
+from tensorflow.keras.models import Sequential, Model
+from tensorflow.keras.optimizers import Adam
+
+from lib.virtual_batch import VirtualBatchNormalization
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
@@ -17,7 +21,7 @@ import sys
import numpy as np
class DCGAN():
- def __init__(self, conv_layers = 1):
+ def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -25,6 +29,7 @@ class DCGAN():
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
self.conv_layers = conv_layers
+ self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.002, 0.5)
@@ -62,14 +67,21 @@ class DCGAN():
for i in range(self.conv_layers):
model.add(Conv2D(128, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization())
+ else:
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(UpSampling2D())
for i in range(self.conv_layers):
model.add(Conv2D(64, kernel_size=3, padding="same"))
- model.add(BatchNormalization())
+ if self.virtual_batch_normalization:
+ model.add(VirtualBatchNormalization())
+ else:
+ model.add(BatchNormalization())
+
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -137,6 +149,8 @@ class DCGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
+ tf.keras.backend.get_session().run(tf.global_variables_initializer())
+
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)
diff --git a/lib/__pycache__/virtual_batch.cpython-37.pyc b/lib/__pycache__/virtual_batch.cpython-37.pyc
new file mode 100644
index 0000000..1ca89c1
--- /dev/null
+++ b/lib/__pycache__/virtual_batch.cpython-37.pyc
Binary files differ
diff --git a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc b/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc
new file mode 100644
index 0000000..1d41d7f
--- /dev/null
+++ b/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc
Binary files differ
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
new file mode 100644
index 0000000..dab0419
--- /dev/null
+++ b/lib/virtual_batch.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Layer
+from lib.virtual_batchnorm_impl import VBN
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras import initializers
+
+class VirtualBatchNormalization(Layer):
+ def __init__(self,
+ momentum=0.99,
+ center=True,
+ scale=True,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ beta_regularizer=None,
+ gamma_regularizer=None,
+ **kwargs):
+
+ self.beta_initializer = initializers.get(beta_initializer)
+ self.gamma_initializer = initializers.get(gamma_initializer)
+
+ super(VirtualBatchNormalization, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ if not input_shape.ndims:
+ raise ValueError('Input has undefined rank:', input_shape)
+ ndims = len(input_shape)
+ self.input_spec = InputSpec(ndim=ndims)
+ #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end
+
+ def call(self, x):
+ outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
+ outputs.set_shape(x.get_shape())
+ return outputs
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
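A minimal standalone sketch of the wrapper layer above, under the same TF 1.x graph-mode assumptions as the train() changes in cgan.py and dcgan.py (the explicit global_variables_initializer() run is needed because VBN creates beta/gamma through variable_scope, outside Keras's own weight tracking); the shapes and variable names here are illustrative, not taken from the commit:

    import numpy as np
    import tensorflow as tf
    from lib.virtual_batch import VirtualBatchNormalization

    inp = tf.keras.layers.Input(shape=(256,))
    # The layer builds a VBN op that uses its own input as the reference batch.
    out = VirtualBatchNormalization()(inp)
    model = tf.keras.models.Model(inp, out)

    # Initialize the VBN-created variables by hand, as train() does above.
    tf.keras.backend.get_session().run(tf.global_variables_initializer())
    print(model.predict(np.random.normal(size=(32, 256))).shape)  # expected (32, 256)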
diff --git a/lib/virtual_batchnorm_impl.py b/lib/virtual_batchnorm_impl.py
new file mode 100644
index 0000000..650eab9
--- /dev/null
+++ b/lib/virtual_batchnorm_impl.py
@@ -0,0 +1,306 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Virtual batch normalization.
+
+This technique was first introduced in `Improved Techniques for Training GANs`
+(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
+normalization on a minibatch, it fixes a reference subset of the data to use for
+calculating normalization statistics.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+
+__all__ = [
+ 'VBN',
+]
+
+
+def _static_or_dynamic_batch_size(tensor, batch_axis):
+ """Returns the static or dynamic batch size."""
+ batch_size = array_ops.shape(tensor)[batch_axis]
+ static_batch_size = tensor_util.constant_value(batch_size)
+ return static_batch_size or batch_size
+
+
+def _statistics(x, axes):
+ """Calculate the mean and mean square of `x`.
+
+ Modified from the implementation of `tf.nn.moments`.
+
+ Args:
+ x: A `Tensor`.
+ axes: Array of ints. Axes along which to compute mean and
+ variance.
+
+ Returns:
+ Two `Tensor` objects: `mean` and `square mean`.
+ """
+ # The dynamic range of fp16 is too limited to support the collection of
+ # sufficient statistics. As a workaround we simply perform the operations
+ # on 32-bit floats before converting the mean and variance back to fp16
+ y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
+
+ # Compute true mean while keeping the dims for proper broadcasting.
+ shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
+
+ shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True)
+ mean = shifted_mean + shift
+ mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)
+
+ mean = array_ops.squeeze(mean, axes)
+ mean_squared = array_ops.squeeze(mean_squared, axes)
+ if x.dtype == dtypes.float16:
+ return (math_ops.cast(mean, dtypes.float16),
+ math_ops.cast(mean_squared, dtypes.float16))
+ else:
+ return (mean, mean_squared)
+
+
+def _validate_init_input_and_get_axis(reference_batch, axis):
+ """Validate input and return the used axis value."""
+ if reference_batch.shape.ndims is None:
+ raise ValueError('`reference_batch` has unknown dimensions.')
+
+ ndims = reference_batch.shape.ndims
+ if axis < 0:
+ used_axis = ndims + axis
+ else:
+ used_axis = axis
+ if used_axis < 0 or used_axis >= ndims:
+ raise ValueError('Value of `axis` argument ' + str(used_axis) +
+ ' is out of range for input with rank ' + str(ndims))
+ return used_axis
+
+
+def _validate_call_input(tensor_list, batch_dim):
+ """Verifies that tensor shapes are compatible, except for `batch_dim`."""
+ def _get_shape(tensor):
+ shape = tensor.shape.as_list()
+ del shape[batch_dim]
+ return shape
+ base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
+ for tensor in tensor_list:
+ base_shape.assert_is_compatible_with(_get_shape(tensor))
+
+
+class VBN(object):
+ """A class to perform virtual batch normalization.
+
+ This technique was first introduced in `Improved Techniques for Training GANs`
+ (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
+ normalization on a minibatch, it fixes a reference subset of the data to use
+ for calculating normalization statistics.
+
+ To do this, we calculate the reference batch mean and mean square, and modify
+ those statistics for each example. We use mean square instead of variance,
+ since it is linear.
+
+ Note that if `center` or `scale` variables are created, they are shared
+ between all calls to this object.
+
+ The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
+ closely as possible.
+ """
+
+ def __init__(self,
+ reference_batch,
+ axis=-1,
+ epsilon=1e-3,
+ center=True,
+ scale=True,
+ beta_initializer=init_ops.zeros_initializer(),
+ gamma_initializer=init_ops.ones_initializer(),
+ beta_regularizer=None,
+ gamma_regularizer=None,
+ trainable=True,
+ name=None,
+ batch_axis=0):
+ """Initialize virtual batch normalization object.
+
+ We precompute the 'mean' and 'mean squared' of the reference batch, so that
+ `__call__` is efficient. This means that the axis must be supplied when the
+ object is created, not when it is called.
+
+ We precompute 'square mean' instead of 'variance', because the square mean
+ can be easily adjusted on a per-example basis.
+
+ Args:
+ reference_batch: A minibatch tensor. This will form the reference data
+ from which the normalization statistics are calculated. See
+ https://arxiv.org/abs/1606.03498 for more details.
+ axis: Integer, the axis that should be normalized (typically the features
+ axis). For instance, after a `Convolution2D` layer with
+ `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
+ epsilon: Small float added to variance to avoid dividing by zero.
+ center: If True, add offset of `beta` to normalized tensor. If False,
+ `beta` is ignored.
+ scale: If True, multiply by `gamma`. If False, `gamma` is
+ not used. When the next layer is linear (also e.g. `nn.relu`), this can
+ be disabled since the scaling can be done by the next layer.
+ beta_initializer: Initializer for the beta weight.
+ gamma_initializer: Initializer for the gamma weight.
+ beta_regularizer: Optional regularizer for the beta weight.
+ gamma_regularizer: Optional regularizer for the gamma weight.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ name: String, the name of the ops.
+ batch_axis: The axis of the batch dimension. This dimension is treated
+ differently in `virtual batch normalization` vs `batch normalization`.
+
+ Raises:
+ ValueError: If `reference_batch` has unknown dimensions at graph
+ construction.
+ ValueError: If `batch_axis` is the same as `axis`.
+ """
+ axis = _validate_init_input_and_get_axis(reference_batch, axis)
+ self._epsilon = epsilon
+ self._beta = 0
+ self._gamma = 1
+ self._batch_axis = _validate_init_input_and_get_axis(
+ reference_batch, batch_axis)
+
+ if axis == self._batch_axis:
+ raise ValueError('`axis` and `batch_axis` cannot be the same.')
+
+ with variable_scope.variable_scope(name, 'VBN',
+ values=[reference_batch]) as self._vs:
+ self._reference_batch = reference_batch
+
+ # Calculate important shapes:
+ # 1) Reduction axes for the reference batch
+ # 2) Broadcast shape, if necessary
+ # 3) Reduction axes for the virtual batchnormed batch
+ # 4) Shape for optional parameters
+ input_shape = self._reference_batch.shape
+ ndims = input_shape.ndims
+ reduction_axes = list(range(ndims))
+ del reduction_axes[axis]
+
+ self._broadcast_shape = [1] * len(input_shape)
+ self._broadcast_shape[axis] = input_shape[axis].value
+
+ self._example_reduction_axes = list(range(ndims))
+ del self._example_reduction_axes[max(axis, self._batch_axis)]
+ del self._example_reduction_axes[min(axis, self._batch_axis)]
+
+ params_shape = self._reference_batch.shape[axis]
+
+ # Determines whether broadcasting is needed. This is slightly different
+ # than in the `nn.batch_normalization` case, due to `batch_dim`.
+ self._needs_broadcasting = (
+ sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
+
+ # Calculate the sufficient statistics for the reference batch in a way
+ # that can be easily modified by additional examples.
+ self._ref_mean, self._ref_mean_squares = _statistics(
+ self._reference_batch, reduction_axes)
+ self._ref_variance = (self._ref_mean_squares -
+ math_ops.square(self._ref_mean))
+
+ # Virtual batch normalization uses a weighted average between example
+ # statistics and the reference batch statistics.
+ ref_batch_size = _static_or_dynamic_batch_size(
+ self._reference_batch, self._batch_axis)
+ self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
+ self._ref_weight = 1. - self._example_weight
+
+ # Make the variables, if necessary.
+ if center:
+ self._beta = variable_scope.get_variable(
+ name='beta',
+ shape=(params_shape,),
+ initializer=beta_initializer,
+ regularizer=beta_regularizer,
+ trainable=trainable)
+ if scale:
+ self._gamma = variable_scope.get_variable(
+ name='gamma',
+ shape=(params_shape,),
+ initializer=gamma_initializer,
+ regularizer=gamma_regularizer,
+ trainable=trainable)
+
+ def _virtual_statistics(self, inputs, reduction_axes):
+ """Compute the statistics needed for virtual batch normalization."""
+ cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
+ vb_mean = (self._example_weight * cur_mean +
+ self._ref_weight * self._ref_mean)
+ vb_mean_sq = (self._example_weight * cur_mean_sq +
+ self._ref_weight * self._ref_mean_squares)
+ return (vb_mean, vb_mean_sq)
+
+ def _broadcast(self, v, broadcast_shape=None):
+ # The exact broadcast shape depends on the current batch, not the reference
+ # batch, unless we're calculating the batch normalization of the reference
+ # batch.
+ b_shape = broadcast_shape or self._broadcast_shape
+ if self._needs_broadcasting and v is not None:
+ return array_ops.reshape(v, b_shape)
+ return v
+
+ def reference_batch_normalization(self):
+ """Return the reference batch, but batch normalized."""
+ with ops.name_scope(self._vs.name):
+ return nn.batch_normalization(self._reference_batch,
+ self._broadcast(self._ref_mean),
+ self._broadcast(self._ref_variance),
+ self._broadcast(self._beta),
+ self._broadcast(self._gamma),
+ self._epsilon)
+
+ def __call__(self, inputs):
+ """Run virtual batch normalization on inputs.
+
+ Args:
+ inputs: Tensor input.
+
+ Returns:
+ A virtual batch normalized version of `inputs`.
+
+ Raises:
+ ValueError: If `inputs` shape isn't compatible with the reference batch.
+ """
+ _validate_call_input([inputs, self._reference_batch], self._batch_axis)
+
+ with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
+ # Calculate the statistics on the current input on a per-example basis.
+ vb_mean, vb_mean_sq = self._virtual_statistics(
+ inputs, self._example_reduction_axes)
+ vb_variance = vb_mean_sq - math_ops.square(vb_mean)
+
+ # The exact broadcast shape of the input statistic Tensors depends on the
+ # current batch, not the reference batch. The parameter broadcast shape
+ # is independent of the shape of the input statistic Tensor dimensions.
+ b_shape = self._broadcast_shape[:] # deep copy
+ b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
+ inputs, self._batch_axis)
+ return nn.batch_normalization(
+ inputs,
+ self._broadcast(vb_mean, b_shape),
+ self._broadcast(vb_variance, b_shape),
+ self._broadcast(self._beta, self._broadcast_shape),
+ self._broadcast(self._gamma, self._broadcast_shape),
+ self._epsilon)
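For reference, the statistics combination that the VBN docstrings above describe can be written out explicitly; a sketch in our own notation (not from the original source), with reference batch size N:

    \mu_{vb} = \frac{1}{N+1}\,\mu_{example} + \frac{N}{N+1}\,\mu_{ref}

    \overline{x^2}_{vb} = \frac{1}{N+1}\,\overline{x^2}_{example} + \frac{N}{N+1}\,\overline{x^2}_{ref}

    \sigma^2_{vb} = \overline{x^2}_{vb} - \mu_{vb}^2

This is why the implementation tracks mean squares rather than variances: only quantities that combine linearly can be mixed with the fixed weights 1/(N+1) and N/(N+1) that _example_weight and _ref_weight store, and the virtual-batch variance is then recovered at the end by subtracting the squared virtual-batch mean.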