author | Vasil Zlatanov <v@skozl.com> | 2019-03-05 14:29:29 +0000
committer | Vasil Zlatanov <v@skozl.com> | 2019-03-05 14:30:03 +0000
commit | 740e1b0c6a02a7bec20008758373f0dd80baade4 (patch)
tree | eb9795d8013c8eb44a01176979ede6bef600ec3f /lib/virtual_batch.py
parent | 802f52a2410ed20cea55e8c097b3875111a80824 (diff)
Add virtual_batch support
Diffstat (limited to 'lib/virtual_batch.py')
-rw-r--r-- | lib/virtual_batch.py | 39
1 file changed, 39 insertions, 0 deletions
diff --git a/lib/virtual_batch.py b/lib/virtual_batch.py
new file mode 100644
index 0000000..dab0419
--- /dev/null
+++ b/lib/virtual_batch.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Layer
+from lib.virtual_batchnorm_impl import VBN
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras import initializers
+
+class VirtualBatchNormalization(Layer):
+    def __init__(self,
+                 momentum=0.99,
+                 center=True,
+                 scale=True,
+                 beta_initializer='zeros',
+                 gamma_initializer='ones',
+                 beta_regularizer=None,
+                 gamma_regularizer=None,
+                 **kwargs):
+
+        self.beta_initializer = initializers.get(beta_initializer)
+        self.gamma_initializer = initializers.get(gamma_initializer)
+
+        super(VirtualBatchNormalization, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        input_shape = tensor_shape.TensorShape(input_shape)
+        if not input_shape.ndims:
+            raise ValueError('Input has undefined rank:', input_shape)
+        ndims = len(input_shape)
+        self.input_spec = InputSpec(ndim=ndims)
+        #super(VirtualBatchNormalization, self).build(input_shape) # Be sure to call this at the end
+
+    def call(self, x):
+        outputs = VBN(x, gamma_initializer=self.gamma_initializer, beta_initializer=self.beta_initializer)(x)
+        outputs.set_shape(x.get_shape())
+        return outputs
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
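The new file wraps VBN from lib/virtual_batchnorm_impl in a standard Keras Layer, so it can be dropped into a model wherever BatchNormalization would normally go. Below is a minimal usage sketch, assuming TensorFlow 1.x and that the repository's lib/ directory is importable; the generator layout is illustrative, not the repository's actual architecture:

```python
# Usage sketch only: assumes TensorFlow 1.x graph mode and that lib/ is on the
# import path; the layer sizes below are illustrative (MNIST-sized output).
from tensorflow.keras import layers, models

from lib.virtual_batch import VirtualBatchNormalization

latent_dim = 100

generator = models.Sequential([
    layers.Dense(7 * 7 * 128, input_dim=latent_dim),
    VirtualBatchNormalization(),   # used as a drop-in for layers.BatchNormalization()
    layers.Activation('relu'),
    layers.Reshape((7, 7, 128)),
    layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding='same'),
    VirtualBatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding='same',
                           activation='tanh'),
])

generator.summary()
```

Note that call() passes the incoming tensor both as VBN's reference batch and as its input, so the reference statistics here come from the current mini-batch rather than from a fixed batch sampled once at the start of training.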