about | summary | refs | log | tree | commit | diff
path: root/lib/virtual_batch.py
blob: dab04197660d26150fef6b1e4c696bb3c46c60c2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from lib.virtual_batchnorm_impl import VBN
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras import initializers

class VirtualBatchNormalization(Layer):
    """Keras layer that applies ``VBN`` (from lib.virtual_batchnorm_impl).

    Each call constructs a ``VBN`` with this layer's beta/gamma initializers
    and applies it to the input; the output keeps the input's static shape.

    NOTE(review): ``momentum``, ``center``, ``scale``, ``beta_regularizer`` and
    ``gamma_regularizer`` are stored and round-tripped through ``get_config``
    but are NOT passed to ``VBN`` in ``call``, so they currently have no effect
    on the output. TODO: forward them if ``VBN``'s signature supports them.
    """

    def __init__(self,
                 momentum=0.99,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 **kwargs):
        """Store the layer configuration.

        Args:
            momentum: Stored for configuration only; not used by ``call``.
            center: Stored for configuration only; not used by ``call``.
            scale: Stored for configuration only; not used by ``call``.
            beta_initializer: Initializer identifier or instance, resolved
                with ``initializers.get`` and passed to ``VBN``.
            gamma_initializer: Initializer identifier or instance, resolved
                with ``initializers.get`` and passed to ``VBN``.
            beta_regularizer: Regularizer identifier or instance, resolved
                with ``tf.keras.regularizers.get``; stored for configuration
                only.
            gamma_regularizer: Same as ``beta_regularizer``.
            **kwargs: Forwarded unchanged to ``Layer.__init__``.
        """
        # Initialise the base Layer before assigning attributes on ``self`` so
        # Keras' attribute tracking is fully set up.
        super(VirtualBatchNormalization, self).__init__(**kwargs)
        self.momentum = momentum
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        # Resolved so that ``get_config`` can serialize them and
        # ``from_config`` can rebuild them.
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)

    def build(self, input_shape):
        """Validate the input rank and pin it with an ``InputSpec``.

        Raises:
            ValueError: If the rank of ``input_shape`` is unknown, or if it
                is a scalar (rank 0) shape.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        # ``ndims`` is None only when the rank is unknown; test that
        # explicitly so a rank-0 shape is not misreported as "undefined".
        if input_shape.ndims is None:
            raise ValueError('Input has undefined rank: %s' % (input_shape,))
        if input_shape.ndims == 0:
            raise ValueError('Input must have rank >= 1, got shape: %s'
                             % (input_shape,))
        # Later calls must supply inputs of the same rank.
        self.input_spec = InputSpec(ndim=input_shape.ndims)
        super(VirtualBatchNormalization, self).build(input_shape)

    def call(self, x):
        """Apply ``VBN`` to ``x`` and return a tensor with ``x``'s static shape."""
        # NOTE(review): a new VBN is constructed on every call and ``x`` is
        # passed as its first argument. If that argument is VBN's reference
        # batch, each batch is normalized against itself rather than against
        # a fixed reference batch. Variables VBN creates do not go through
        # ``self.add_weight``, so they may not be listed in this layer's
        # weights -- TODO confirm against lib.virtual_batchnorm_impl.
        normalizer = VBN(x,
                         gamma_initializer=self.gamma_initializer,
                         beta_initializer=self.beta_initializer)
        outputs = normalizer(x)
        # Re-assert the input's static shape on the result.
        outputs.set_shape(x.get_shape())
        return outputs

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged; normalization preserves shape."""
        return input_shape

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer.

        Keras requires layers whose ``__init__`` takes extra arguments to
        override this method for model serialization.
        """
        config = {
            'momentum': self.momentum,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer':
                tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer':
                tf.keras.regularizers.serialize(self.gamma_regularizer),
        }
        base_config = super(VirtualBatchNormalization, self).get_config()
        base_config.update(config)
        return base_config