aboutsummaryrefslogtreecommitdiff
path: root/lib/virtual_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r--lib/virtual_batch.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
new file mode 100644
index 0000000..dab0419
--- /dev/null
+++ b/lib/virtual_batch.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Layer
+from lib.virtual_batchnorm_impl import VBN
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras import initializers
+
+class VirtualBatchNormalization(Layer):
+ def __init__(self,
+ momentum=0.99,
+ center=True,
+ scale=True,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ beta_regularizer=None,
+ gamma_regularizer=None,
+ **kwargs):
+
+ self.beta_initializer = initializers.get(beta_initializer)
+ self.gamma_initializer = initializers.get(gamma_initializer)
+
+ super(VirtualBatchNormalization, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ if not input_shape.ndims:
+ raise ValueError('Input has undefined rank:', input_shape)
+ ndims = len(input_shape)
+ self.input_spec = InputSpec(ndim=ndims)
+ #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end
+
+ def call(self, x):
+ outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
+ outputs.set_shape(x.get_shape())
+ return outputs
+
+ def compute_output_shape(self, input_shape):
+ return input_shape