1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from lib.virtual_batchnorm_impl import VBN
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras import initializers
class VirtualBatchNormalization(Layer):
    """Keras layer wrapping the project-local ``VBN`` (virtual batch norm).

    On every call a ``VBN`` instance is constructed with the incoming tensor
    ``x`` as its reference batch and then applied to ``x``; the output keeps
    the static shape of ``x``.

    Args:
        momentum: Stored (and serialized in ``get_config``) for API
            compatibility only; this layer does not use it.
        center: Forwarded to ``VBN``; presumably toggles the learned offset
            ``beta`` (TODO confirm against lib.virtual_batchnorm_impl).
        scale: Forwarded to ``VBN``; presumably toggles the learned scale
            ``gamma`` (TODO confirm against lib.virtual_batchnorm_impl).
        beta_initializer: Keras initializer or identifier for ``beta``.
        gamma_initializer: Keras initializer or identifier for ``gamma``.
        beta_regularizer: Optional Keras regularizer or identifier for ``beta``.
        gamma_regularizer: Optional Keras regularizer or identifier for
            ``gamma``.
        **kwargs: Standard Keras ``Layer`` keyword arguments.
    """

    def __init__(self,
                 momentum=0.99,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 **kwargs):
        super(VirtualBatchNormalization, self).__init__(**kwargs)
        self.momentum = momentum
        self.center = center
        self.scale = scale
        # Resolve identifiers through the public tf.keras API so the objects
        # match the Layer base class and can be serialized in get_config.
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)

    def build(self, input_shape):
        """Record the expected input rank.

        Raises:
            ValueError: if the input rank is unknown or zero.
        """
        input_shape = tf.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError(
                'Input must have a known, non-zero rank; got shape {}'.format(
                    input_shape))
        # Use the public InputSpec: the one under tensorflow.python.keras is a
        # different class in newer TF and is rejected by Layer.input_spec.
        self.input_spec = tf.keras.layers.InputSpec(ndim=input_shape.ndims)
        super(VirtualBatchNormalization, self).build(input_shape)

    def call(self, x):
        """Normalize ``x`` using ``x`` itself as the VBN reference batch."""
        # NOTE(review): a fresh VBN is built on every call. Whether its
        # beta/gamma variables are reused across calls and tracked by this
        # layer depends on lib.virtual_batchnorm_impl -- confirm.
        vbn = VBN(x,
                  center=self.center,
                  scale=self.scale,
                  beta_initializer=self.beta_initializer,
                  gamma_initializer=self.gamma_initializer,
                  beta_regularizer=self.beta_regularizer,
                  gamma_regularizer=self.gamma_regularizer)
        outputs = vbn(x)
        # Restore the static shape so downstream layers see the input's shape.
        outputs.set_shape(x.get_shape())
        return outputs

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'momentum': self.momentum,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer':
                tf.keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer':
                tf.keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer':
                tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer':
                tf.keras.regularizers.serialize(self.gamma_regularizer),
        }
        base_config = super(VirtualBatchNormalization, self).get_config()
        base_config.update(config)
        return base_config
|