diff options
author | Vasil Zlatanov <vasil@netcraft.com> | 2019-03-06 20:39:00 +0000 |
---|---|---|
committer | Vasil Zlatanov <vasil@netcraft.com> | 2019-03-06 20:39:00 +0000 |
commit | bc501637cdb329db681b439563cdae418f3fa897 (patch) | |
tree | c214be8307c7e64d8586104b3308b1073b9380fb /lib/virtual_batch.py | |
parent | f2d09edb7fb511364347ae9df1915a6655f45a0a (diff) | |
download | e4-gan-bc501637cdb329db681b439563cdae418f3fa897.tar.gz e4-gan-bc501637cdb329db681b439563cdae418f3fa897.tar.bz2 e4-gan-bc501637cdb329db681b439563cdae418f3fa897.zip |
Revert "Add virtual_batch support"
This reverts commit 740e1b0c6a02a7bec20008758373f0dd80baade4.
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r-- | lib/virtual_batch.py | 39 |
1 files changed, 0 insertions, 39 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py deleted file mode 100644 index dab0419..0000000 --- a/lib/virtual_batch.py +++ /dev/null @@ -1,39 +0,0 @@ -import tensorflow as tf -from tensorflow.keras import backend as K -from tensorflow.keras.layers import Layer -from lib.virtual_batchnorm_impl import VBN -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine.base_layer import InputSpec -from tensorflow.python.keras import initializers - -class VirtualBatchNormalization(Layer): - def __init__(self, - momentum=0.99, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - **kwargs): - - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - - super(VirtualBatchNormalization, self).__init__(**kwargs) - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if not input_shape.ndims: - raise ValueError('Input has undefined rank:', input_shape) - ndims = len(input_shape) - self.input_spec = InputSpec(ndim=ndims) - #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end - - def call(self, x): - outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x) - outputs.set_shape(x.get_shape()) - return outputs - - def compute_output_shape(self, input_shape): - return input_shape |