author     Vasil Zlatanov <vasil@netcraft.com>   2019-03-06 20:39:00 +0000
committer  Vasil Zlatanov <vasil@netcraft.com>   2019-03-06 20:39:00 +0000
commit     bc501637cdb329db681b439563cdae418f3fa897 (patch)
tree       c214be8307c7e64d8586104b3308b1073b9380fb
parent     f2d09edb7fb511364347ae9df1915a6655f45a0a (diff)
download   e4-gan-bc501637cdb329db681b439563cdae418f3fa897.tar.gz
           e4-gan-bc501637cdb329db681b439563cdae418f3fa897.tar.bz2
           e4-gan-bc501637cdb329db681b439563cdae418f3fa897.zip
Revert "Add virtual_batch support"
This reverts commit 740e1b0c6a02a7bec20008758373f0dd80baade4.
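For context: virtual batch normalization (Salimans et al., `Improved Techniques for Training GANs`, https://arxiv.org/abs/1606.03498), the feature removed here, normalizes each example against a weighted blend of its own statistics and those of a fixed reference batch, rather than against the current minibatch. Below is a minimal NumPy sketch of that blending for a 2-D (batch, features) input; the function name is illustrative, not part of the deleted API:

import numpy as np

def vbn_normalize(reference_batch, x, eps=1e-3):
    # Weighted average of reference-batch and per-example statistics,
    # mirroring the deleted virtual_batchnorm_impl.py: a reference
    # batch of size N gets weight N/(N+1), the example gets 1/(N+1).
    n = reference_batch.shape[0]
    w_example = 1.0 / (n + 1.0)
    w_ref = 1.0 - w_example

    # Mean square is tracked instead of variance because it is linear
    # in the examples and so can be adjusted per example.
    ref_mean = reference_batch.mean(axis=0)
    ref_mean_sq = (reference_batch ** 2).mean(axis=0)

    vb_mean = w_ref * ref_mean + w_example * x
    vb_mean_sq = w_ref * ref_mean_sq + w_example * x ** 2
    vb_var = vb_mean_sq - vb_mean ** 2
    return (x - vb_mean) / np.sqrt(vb_var + eps)

reference = np.random.randn(64, 100)   # fixed reference batch
sample = np.random.randn(100)          # one incoming example
out = vbn_normalize(reference, sample)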
-rwxr-xr-x [-rw-r--r--]  cgan.py                                                  28
-rw-r--r--               dcgan.py                                                 36
-rw-r--r--               lib/__pycache__/virtual_batch.cpython-37.pyc             bin 1758 -> 0 bytes
-rw-r--r--               lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc    bin 8723 -> 0 bytes
-rw-r--r--               lib/virtual_batch.py                                     39
-rw-r--r--               lib/virtual_batchnorm_impl.py                            306
6 files changed, 22 insertions, 387 deletions
diff --git a/cgan.py b/cgan.py
index 45b9bb9..6406244 100644..100755
--- a/cgan.py
+++ b/cgan.py
@@ -1,23 +1,21 @@
from __future__ import print_function, division
import tensorflow.keras as keras
import tensorflow as tf
-from tensorflow.keras.datasets import mnist
-from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
-from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
-from tensorflow.keras.layers import LeakyReLU
-from tensorflow.keras.layers import UpSampling2D, Conv2D
-from tensorflow.keras.models import Sequential, Model
-from tensorflow.keras.optimizers import Adam
+from keras.datasets import mnist
+from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
+from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling2D, Conv2D
+from keras.models import Sequential, Model
+from keras.optimizers import Adam
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
-from lib.virtual_batch import VirtualBatchNormalization
-
import numpy as np
class CGAN():
- def __init__(self, dense_layers = 3, virtual_batch_normalization=False):
+ def __init__(self, dense_layers = 3):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -26,7 +24,6 @@ class CGAN():
self.num_classes = 10
self.latent_dim = 100
self.dense_layers = dense_layers
- self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.0002, 0.5)
@@ -66,10 +63,7 @@ class CGAN():
output_size = 2**(8+i)
model.add(Dense(output_size, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
- if self.virtual_batch_normalization:
- model.add(VirtualBatchNormalization(momentum=0.8))
- else:
- model.add(BatchNormalization(momentum=0.8))
+ model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
@@ -142,7 +136,6 @@ class CGAN():
# Sample noise as generator input
noise = np.random.normal(0, 1, (batch_size, 100))
- tf.keras.backend.get_session().run(tf.global_variables_initializer())
# Generate a half batch of new images
gen_imgs = self.generator.predict([noise, labels])
@@ -224,9 +217,10 @@ class CGAN():
return train_data, test_data, val_data, labels_train, labels_test, labels_val
+
'''
if __name__ == '__main__':
- cgan = CGAN(dense_layers=1, virtual_batch_normalization=True)
+ cgan = CGAN(dense_layers=1)
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)
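The import changes above (mirrored in dcgan.py below) swap `tensorflow.keras` for the standalone `keras` package. Note that `keras.layers.advanced_activations` and `keras.layers.convolutional` are Keras 2.x module paths that do not exist under `tf.keras`, where the same layers are exported directly from `tensorflow.keras.layers`; the two spellings are equivalent but not interchangeable:

# Standalone Keras 2.x, as restored by this revert:
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D

# tf.keras equivalent, as used by the reverted commit:
# from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D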
diff --git a/dcgan.py b/dcgan.py
index 21afaac..347f61e 100644
--- a/dcgan.py
+++ b/dcgan.py
@@ -1,16 +1,11 @@
from __future__ import print_function, division
-import tensorflow as keras
-
-import tensorflow as tf
-from tensorflow.keras.datasets import mnist
-from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
-from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
-from tensorflow.keras.layers import LeakyReLU
-from tensorflow.keras.layers import UpSampling2D, Conv2D
-from tensorflow.keras.models import Sequential, Model
-from tensorflow.keras.optimizers import Adam
-
-from lib.virtual_batch import VirtualBatchNormalization
+from keras.datasets import mnist
+from keras.layers import Input, Dense, Reshape, Flatten, Dropout
+from keras.layers import BatchNormalization, Activation, ZeroPadding2D
+from keras.layers.advanced_activations import LeakyReLU
+from keras.layers.convolutional import UpSampling2D, Conv2D
+from keras.models import Sequential, Model
+from keras.optimizers import Adam
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
@@ -22,7 +17,7 @@ import sys
import numpy as np
class DCGAN():
- def __init__(self, conv_layers = 1, virtual_batch_normalization=False):
+ def __init__(self, conv_layers = 1):
# Input shape
self.img_rows = 28
self.img_cols = 28
@@ -30,7 +25,6 @@ class DCGAN():
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
self.conv_layers = conv_layers
- self.virtual_batch_normalization = virtual_batch_normalization
optimizer = Adam(0.002, 0.5)
@@ -68,21 +62,14 @@ class DCGAN():
for i in range(self.conv_layers):
model.add(Conv2D(128, kernel_size=3, padding="same"))
- if self.virtual_batch_normalization:
- model.add(VirtualBatchNormalization())
- else:
- model.add(BatchNormalization())
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(UpSampling2D())
for i in range(self.conv_layers):
model.add(Conv2D(64, kernel_size=3, padding="same"))
- if self.virtual_batch_normalization:
- model.add(VirtualBatchNormalization())
- else:
- model.add(BatchNormalization())
-
+ model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
@@ -151,7 +138,6 @@ class DCGAN():
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
- tf.keras.backend.get_session().run(tf.global_variables_initializer())
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)
@@ -203,6 +189,6 @@ class DCGAN():
'''
if __name__ == '__main__':
- dcgan = DCGAN(virtual_batch_normalization=True)
+ dcgan = DCGAN()
dcgan.train(epochs=4000, batch_size=32, save_interval=50)
'''
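Both training loops also lose the line `tf.keras.backend.get_session().run(tf.global_variables_initializer())`. Under TF1 semantics, running the global initializer inside the batch loop re-initializes every graph variable each step, discarding whatever the networks have learned, so dropping it is a correctness fix as well as part of the revert. If explicit initialization is needed at all, it belongs once, before the loop; a sketch assuming a TF1.x session:

import tensorflow as tf

# Initialize once, before training, for TF1-style sessions.
sess = tf.keras.backend.get_session()
sess.run(tf.global_variables_initializer())

for step in range(1000):
    # train_on_batch / predict calls go here; re-running the
    # initializer inside this loop would reset all weights each step.
    pass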
diff --git a/lib/__pycache__/virtual_batch.cpython-37.pyc b/lib/__pycache__/virtual_batch.cpython-37.pyc
deleted file mode 100644
index 1ca89c1..0000000
--- a/lib/__pycache__/virtual_batch.cpython-37.pyc
+++ /dev/null
Binary files differ
diff --git a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc b/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc
deleted file mode 100644
index 1d41d7f..0000000
--- a/lib/__pycache__/virtual_batchnorm_impl.cpython-37.pyc
+++ /dev/null
Binary files differ
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
deleted file mode 100644
index dab0419..0000000
--- a/lib/virtual_batch.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import Layer
-from lib.virtual_batchnorm_impl import VBN
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine.base_layer import InputSpec
-from tensorflow.python.keras import initializers
-
-class VirtualBatchNormalization(Layer):
- def __init__(self,
- momentum=0.99,
- center=True,
- scale=True,
- beta_initializer='zeros',
- gamma_initializer='ones',
- beta_regularizer=None,
- gamma_regularizer=None,
- **kwargs):
-
- self.beta_initializer = initializers.get(beta_initializer)
- self.gamma_initializer = initializers.get(gamma_initializer)
-
- super(VirtualBatchNormalization, self).__init__(**kwargs)
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if not input_shape.ndims:
- raise ValueError('Input has undefined rank:', input_shape)
- ndims = len(input_shape)
- self.input_spec = InputSpec(ndim=ndims)
- #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end
-
- def call(self, x):
- outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
- outputs.set_shape(x.get_shape())
- return outputs
-
- def compute_output_shape(self, input_shape):
- return input_shape
diff --git a/lib/virtual_batchnorm_impl.py b/lib/virtual_batchnorm_impl.py
deleted file mode 100644
index 650eab9..0000000
--- a/lib/virtual_batchnorm_impl.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Virtual batch normalization.
-
-This technique was first introduced in `Improved Techniques for Training GANs`
-(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
-normalization on a minibatch, it fixes a reference subset of the data to use for
-calculating normalization statistics.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import variable_scope
-
-__all__ = [
- 'VBN',
-]
-
-
-def _static_or_dynamic_batch_size(tensor, batch_axis):
- """Returns the static or dynamic batch size."""
- batch_size = array_ops.shape(tensor)[batch_axis]
- static_batch_size = tensor_util.constant_value(batch_size)
- return static_batch_size or batch_size
-
-
-def _statistics(x, axes):
- """Calculate the mean and mean square of `x`.
-
- Modified from the implementation of `tf.nn.moments`.
-
- Args:
- x: A `Tensor`.
- axes: Array of ints. Axes along which to compute mean and
- variance.
-
- Returns:
- Two `Tensor` objects: `mean` and `square mean`.
- """
- # The dynamic range of fp16 is too limited to support the collection of
- # sufficient statistics. As a workaround we simply perform the operations
- # on 32-bit floats before converting the mean and variance back to fp16
- y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
-
- # Compute true mean while keeping the dims for proper broadcasting.
- shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
-
- shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True)
- mean = shifted_mean + shift
- mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)
-
- mean = array_ops.squeeze(mean, axes)
- mean_squared = array_ops.squeeze(mean_squared, axes)
- if x.dtype == dtypes.float16:
- return (math_ops.cast(mean, dtypes.float16),
- math_ops.cast(mean_squared, dtypes.float16))
- else:
- return (mean, mean_squared)
-
-
-def _validate_init_input_and_get_axis(reference_batch, axis):
- """Validate input and return the used axis value."""
- if reference_batch.shape.ndims is None:
- raise ValueError('`reference_batch` has unknown dimensions.')
-
- ndims = reference_batch.shape.ndims
- if axis < 0:
- used_axis = ndims + axis
- else:
- used_axis = axis
- if used_axis < 0 or used_axis >= ndims:
- raise ValueError('Value of `axis` argument ' + str(used_axis) +
- ' is out of range for input with rank ' + str(ndims))
- return used_axis
-
-
-def _validate_call_input(tensor_list, batch_dim):
- """Verifies that tensor shapes are compatible, except for `batch_dim`."""
- def _get_shape(tensor):
- shape = tensor.shape.as_list()
- del shape[batch_dim]
- return shape
- base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
- for tensor in tensor_list:
- base_shape.assert_is_compatible_with(_get_shape(tensor))
-
-
-class VBN(object):
- """A class to perform virtual batch normalization.
-
- This technique was first introduced in `Improved Techniques for Training GANs`
- (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
- normalization on a minibatch, it fixes a reference subset of the data to use
- for calculating normalization statistics.
-
- To do this, we calculate the reference batch mean and mean square, and modify
- those statistics for each example. We use mean square instead of variance,
- since it is linear.
-
- Note that if `center` or `scale` variables are created, they are shared
- between all calls to this object.
-
- The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
- closely as possible.
- """
-
- def __init__(self,
- reference_batch,
- axis=-1,
- epsilon=1e-3,
- center=True,
- scale=True,
- beta_initializer=init_ops.zeros_initializer(),
- gamma_initializer=init_ops.ones_initializer(),
- beta_regularizer=None,
- gamma_regularizer=None,
- trainable=True,
- name=None,
- batch_axis=0):
- """Initialize virtual batch normalization object.
-
- We precompute the 'mean' and 'mean squared' of the reference batch, so that
- `__call__` is efficient. This means that the axis must be supplied when the
- object is created, not when it is called.
-
- We precompute 'square mean' instead of 'variance', because the square mean
- can be easily adjusted on a per-example basis.
-
- Args:
- reference_batch: A minibatch of tensors. This will form the reference data
- from which the normalization statistics are calculated. See
- https://arxiv.org/abs/1606.03498 for more details.
- axis: Integer, the axis that should be normalized (typically the features
- axis). For instance, after a `Convolution2D` layer with
- `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
- epsilon: Small float added to variance to avoid dividing by zero.
- center: If True, add offset of `beta` to normalized tensor. If False,
- `beta` is ignored.
- scale: If True, multiply by `gamma`. If False, `gamma` is
- not used. When the next layer is linear (also e.g. `nn.relu`), this can
- be disabled since the scaling can be done by the next layer.
- beta_initializer: Initializer for the beta weight.
- gamma_initializer: Initializer for the gamma weight.
- beta_regularizer: Optional regularizer for the beta weight.
- gamma_regularizer: Optional regularizer for the gamma weight.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
- name: String, the name of the ops.
- batch_axis: The axis of the batch dimension. This dimension is treated
- differently in `virtual batch normalization` vs `batch normalization`.
-
- Raises:
- ValueError: If `reference_batch` has unknown dimensions at graph
- construction.
- ValueError: If `batch_axis` is the same as `axis`.
- """
- axis = _validate_init_input_and_get_axis(reference_batch, axis)
- self._epsilon = epsilon
- self._beta = 0
- self._gamma = 1
- self._batch_axis = _validate_init_input_and_get_axis(
- reference_batch, batch_axis)
-
- if axis == self._batch_axis:
- raise ValueError('`axis` and `batch_axis` cannot be the same.')
-
- with variable_scope.variable_scope(name, 'VBN',
- values=[reference_batch]) as self._vs:
- self._reference_batch = reference_batch
-
- # Calculate important shapes:
- # 1) Reduction axes for the reference batch
- # 2) Broadcast shape, if necessary
- # 3) Reduction axes for the virtual batchnormed batch
- # 4) Shape for optional parameters
- input_shape = self._reference_batch.shape
- ndims = input_shape.ndims
- reduction_axes = list(range(ndims))
- del reduction_axes[axis]
-
- self._broadcast_shape = [1] * len(input_shape)
- self._broadcast_shape[axis] = input_shape[axis].value
-
- self._example_reduction_axes = list(range(ndims))
- del self._example_reduction_axes[max(axis, self._batch_axis)]
- del self._example_reduction_axes[min(axis, self._batch_axis)]
-
- params_shape = self._reference_batch.shape[axis]
-
- # Determines whether broadcasting is needed. This is slightly different
- # than in the `nn.batch_normalization` case, due to `batch_dim`.
- self._needs_broadcasting = (
- sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
-
- # Calculate the sufficient statistics for the reference batch in a way
- # that can be easily modified by additional examples.
- self._ref_mean, self._ref_mean_squares = _statistics(
- self._reference_batch, reduction_axes)
- self._ref_variance = (self._ref_mean_squares -
- math_ops.square(self._ref_mean))
-
- # Virtual batch normalization uses a weighted average between example
- # statistics and the reference batch statistics.
- ref_batch_size = _static_or_dynamic_batch_size(
- self._reference_batch, self._batch_axis)
- self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
- self._ref_weight = 1. - self._example_weight
-
- # Make the variables, if necessary.
- if center:
- self._beta = variable_scope.get_variable(
- name='beta',
- shape=(params_shape,),
- initializer=beta_initializer,
- regularizer=beta_regularizer,
- trainable=trainable)
- if scale:
- self._gamma = variable_scope.get_variable(
- name='gamma',
- shape=(params_shape,),
- initializer=gamma_initializer,
- regularizer=gamma_regularizer,
- trainable=trainable)
-
- def _virtual_statistics(self, inputs, reduction_axes):
- """Compute the statistics needed for virtual batch normalization."""
- cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
- vb_mean = (self._example_weight * cur_mean +
- self._ref_weight * self._ref_mean)
- vb_mean_sq = (self._example_weight * cur_mean_sq +
- self._ref_weight * self._ref_mean_squares)
- return (vb_mean, vb_mean_sq)
-
- def _broadcast(self, v, broadcast_shape=None):
- # The exact broadcast shape depends on the current batch, not the reference
- # batch, unless we're calculating the batch normalization of the reference
- # batch.
- b_shape = broadcast_shape or self._broadcast_shape
- if self._needs_broadcasting and v is not None:
- return array_ops.reshape(v, b_shape)
- return v
-
- def reference_batch_normalization(self):
- """Return the reference batch, but batch normalized."""
- with ops.name_scope(self._vs.name):
- return nn.batch_normalization(self._reference_batch,
- self._broadcast(self._ref_mean),
- self._broadcast(self._ref_variance),
- self._broadcast(self._beta),
- self._broadcast(self._gamma),
- self._epsilon)
-
- def __call__(self, inputs):
- """Run virtual batch normalization on inputs.
-
- Args:
- inputs: Tensor input.
-
- Returns:
- A virtual batch normalized version of `inputs`.
-
- Raises:
- ValueError: If `inputs` shape isn't compatible with the reference batch.
- """
- _validate_call_input([inputs, self._reference_batch], self._batch_axis)
-
- with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
- # Calculate the statistics on the current input on a per-example basis.
- vb_mean, vb_mean_sq = self._virtual_statistics(
- inputs, self._example_reduction_axes)
- vb_variance = vb_mean_sq - math_ops.square(vb_mean)
-
- # The exact broadcast shape of the input statistic Tensors depends on the
- # current batch, not the reference batch. The parameter broadcast shape
- # is independent of the shape of the input statistic Tensor dimensions.
- b_shape = self._broadcast_shape[:] # deep copy
- b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
- inputs, self._batch_axis)
- return nn.batch_normalization(
- inputs,
- self._broadcast(vb_mean, b_shape),
- self._broadcast(vb_variance, b_shape),
- self._broadcast(self._beta, self._broadcast_shape),
- self._broadcast(self._gamma, self._broadcast_shape),
- self._epsilon)
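For reference, the deleted `VBN` class precomputed the reference batch's mean and mean square at construction, so each `__call__` only had to fold in one example's statistics. A usage sketch pieced together from the removed docstrings (TF1-style graph code; hypothetical now that the module is gone):

import tensorflow as tf
from lib.virtual_batchnorm_impl import VBN  # module deleted by this commit

reference = tf.placeholder(tf.float32, [64, 100])  # fixed reference batch
inputs = tf.placeholder(tf.float32, [None, 100])   # current minibatch

vbn = VBN(reference)      # reference statistics cached at construction
normalized = vbn(inputs)  # each example blended with the cached stats
ref_normalized = vbn.reference_batch_normalization()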