import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from lib.virtual_batchnorm_impl import VBN
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers


class VirtualBatchNormalization(Layer):
    """Keras layer wrapping the project's ``VBN`` (virtual batch norm) helper.

    On every call a fresh ``VBN`` object is constructed from the input
    tensor and then applied to that same tensor. The output is given the
    input's static shape.

    Args:
        momentum: Stored and exported via ``get_config``.
            NOTE(review): not forwarded to ``VBN`` -- confirm whether the
            ``VBN`` implementation accepts it.
        center: Stored and exported via ``get_config``; not forwarded to
            ``VBN`` (see note on ``momentum``).
        scale: Stored and exported via ``get_config``; not forwarded to
            ``VBN`` (see note on ``momentum``).
        beta_initializer: Initializer identifier or instance, resolved with
            ``initializers.get`` and passed to ``VBN``.
        gamma_initializer: Initializer identifier or instance, resolved with
            ``initializers.get`` and passed to ``VBN``.
        beta_regularizer: Resolved with ``regularizers.get``, stored and
            exported via ``get_config``; not forwarded to ``VBN``.
        gamma_regularizer: Resolved with ``regularizers.get``, stored and
            exported via ``get_config``; not forwarded to ``VBN``.
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``).
    """

    def __init__(self,
                 momentum=0.99,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 **kwargs):
        super(VirtualBatchNormalization, self).__init__(**kwargs)
        # Keep every constructor argument so get_config() can round-trip the
        # layer; previously the non-initializer arguments were dropped.
        self.momentum = momentum
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)

    def build(self, input_shape):
        """Validate the input rank and pin it in ``input_spec``.

        Raises:
            ValueError: If the input rank is unknown (or zero).
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError(
                'Input has undefined rank: {}'.format(input_shape))
        # Later calls must use inputs of the same rank.
        self.input_spec = InputSpec(ndim=input_shape.ndims)
        super(VirtualBatchNormalization, self).build(input_shape)

    def call(self, x):
        """Normalize ``x`` with a ``VBN`` object built from ``x`` itself.

        NOTE(review): the first ``VBN`` argument presumably acts as the
        reference batch; here it is the same tensor being normalized.
        """
        outputs = VBN(x,
                      gamma_initializer=self.gamma_initializer,
                      beta_initializer=self.beta_initializer)(x)
        # Restore the static shape, which the VBN ops may not propagate.
        outputs.set_shape(x.get_shape())
        return outputs

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged; the layer is shape-preserving."""
        return input_shape

    def get_config(self):
        """Return a serializable config so the layer can be re-created."""
        config = {
            'momentum': self.momentum,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
        }
        base_config = super(VirtualBatchNormalization, self).get_config()
        base_config.update(config)
        return base_config