diff options
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r-- | lib/virtual_batch.py | 39 |
1 files changed, 0 insertions, 39 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py deleted file mode 100644 index dab0419..0000000 --- a/lib/virtual_batch.py +++ /dev/null @@ -1,39 +0,0 @@ -import tensorflow as tf -from tensorflow.keras import backend as K -from tensorflow.keras.layers import Layer -from lib.virtual_batchnorm_impl import VBN -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine.base_layer import InputSpec -from tensorflow.python.keras import initializers - -class VirtualBatchNormalization(Layer): - def __init__(self, - momentum=0.99, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - **kwargs): - - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - - super(VirtualBatchNormalization, self).__init__(**kwargs) - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if not input_shape.ndims: - raise ValueError('Input has undefined rank:', input_shape) - ndims = len(input_shape) - self.input_spec = InputSpec(ndim=ndims) - #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end - - def call(self, x): - outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x) - outputs.set_shape(x.get_shape()) - return outputs - - def compute_output_shape(self, input_shape): - return input_shape |