diff options
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r-- | lib/virtual_batch.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py new file mode 100644 index 0000000..dab0419 --- /dev/null +++ b/lib/virtual_batch.py @@ -0,0 +1,39 @@ +import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow.keras.layers import Layer +from lib.virtual_batchnorm_impl import VBN +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras import initializers + +class VirtualBatchNormalization(Layer): + def __init__(self, + momentum=0.99, + center=True, + scale=True, + beta_initializer='zeros', + gamma_initializer='ones', + beta_regularizer=None, + gamma_regularizer=None, + **kwargs): + + self.beta_initializer = initializers.get(beta_initializer) + self.gamma_initializer = initializers.get(gamma_initializer) + + super(VirtualBatchNormalization, self).__init__(**kwargs) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if not input_shape.ndims: + raise ValueError('Input has undefined rank:', input_shape) + ndims = len(input_shape) + self.input_spec = InputSpec(ndim=ndims) + #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end + + def call(self, x): + outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x) + outputs.set_shape(x.get_shape()) + return outputs + + def compute_output_shape(self, input_shape): + return input_shape |