aboutsummaryrefslogtreecommitdiff
path: root/lib/virtual_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r--lib/virtual_batch.py39
1 files changed, 0 insertions, 39 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
deleted file mode 100644
index dab0419..0000000
--- a/lib/virtual_batch.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import Layer
-from lib.virtual_batchnorm_impl import VBN
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine.base_layer import InputSpec
-from tensorflow.python.keras import initializers
-
-class VirtualBatchNormalization(Layer):
- def __init__(self,
- momentum=0.99,
- center=True,
- scale=True,
- beta_initializer='zeros',
- gamma_initializer='ones',
- beta_regularizer=None,
- gamma_regularizer=None,
- **kwargs):
-
- self.beta_initializer = initializers.get(beta_initializer)
- self.gamma_initializer = initializers.get(gamma_initializer)
-
- super(VirtualBatchNormalization, self).__init__(**kwargs)
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if not input_shape.ndims:
- raise ValueError('Input has undefined rank:', input_shape)
- ndims = len(input_shape)
- self.input_spec = InputSpec(ndim=ndims)
- #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end
-
- def call(self, x):
- outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
- outputs.set_shape(x.get_shape())
- return outputs
-
- def compute_output_shape(self, input_shape):
- return input_shape