aboutsummaryrefslogtreecommitdiff
path: root/lib/virtual_batch.py
diff options
context:
space:
mode:
authorVasil Zlatanov <vasil@netcraft.com>2019-03-06 20:39:00 +0000
committerVasil Zlatanov <vasil@netcraft.com>2019-03-06 20:39:00 +0000
commitbc501637cdb329db681b439563cdae418f3fa897 (patch)
treec214be8307c7e64d8586104b3308b1073b9380fb /lib/virtual_batch.py
parentf2d09edb7fb511364347ae9df1915a6655f45a0a (diff)
downloade4-gan-bc501637cdb329db681b439563cdae418f3fa897.tar.gz
e4-gan-bc501637cdb329db681b439563cdae418f3fa897.tar.bz2
e4-gan-bc501637cdb329db681b439563cdae418f3fa897.zip
Revert "Add virtual_batch support"
This reverts commit 740e1b0c6a02a7bec20008758373f0dd80baade4.
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r--lib/virtual_batch.py39
1 files changed, 0 insertions, 39 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
deleted file mode 100644
index dab0419..0000000
--- a/lib/virtual_batch.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import Layer
-from lib.virtual_batchnorm_impl import VBN
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine.base_layer import InputSpec
-from tensorflow.python.keras import initializers
-
-class VirtualBatchNormalization(Layer):
- def __init__(self,
- momentum=0.99,
- center=True,
- scale=True,
- beta_initializer='zeros',
- gamma_initializer='ones',
- beta_regularizer=None,
- gamma_regularizer=None,
- **kwargs):
-
- self.beta_initializer = initializers.get(beta_initializer)
- self.gamma_initializer = initializers.get(gamma_initializer)
-
- super(VirtualBatchNormalization, self).__init__(**kwargs)
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if not input_shape.ndims:
- raise ValueError('Input has undefined rank:', input_shape)
- ndims = len(input_shape)
- self.input_spec = InputSpec(ndim=ndims)
- #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end
-
- def call(self, x):
- outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
- outputs.set_shape(x.get_shape())
- return outputs
-
- def compute_output_shape(self, input_shape):
- return input_shape