From bc501637cdb329db681b439563cdae418f3fa897 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Wed, 6 Mar 2019 20:39:00 +0000 Subject: Revert "Add virtual_batch support" This reverts commit 740e1b0c6a02a7bec20008758373f0dd80baade4. --- lib/virtual_batch.py | 39 --------------------------------------- 1 file changed, 39 deletions(-) delete mode 100644 lib/virtual_batch.py (limited to 'lib/virtual_batch.py') diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py deleted file mode 100644 index dab0419..0000000 --- a/lib/virtual_batch.py +++ /dev/null @@ -1,39 +0,0 @@ -import tensorflow as tf -from tensorflow.keras import backend as K -from tensorflow.keras.layers import Layer -from lib.virtual_batchnorm_impl import VBN -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine.base_layer import InputSpec -from tensorflow.python.keras import initializers - -class VirtualBatchNormalization(Layer): - def __init__(self, - momentum=0.99, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - **kwargs): - - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - - super(VirtualBatchNormalization, self).__init__(**kwargs) - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if not input_shape.ndims: - raise ValueError('Input has undefined rank:', input_shape) - ndims = len(input_shape) - self.input_spec = InputSpec(ndim=ndims) - #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end - - def call(self, x): - outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x) - outputs.set_shape(x.get_shape()) - return outputs - - def compute_output_shape(self, input_shape): - return input_shape -- cgit v1.2.3-54-g00ecf