diff options
Diffstat (limited to 'resnet_preprocessing.py')
-rw-r--r-- | resnet_preprocessing.py | 87 |
1 files changed, 87 insertions, 0 deletions
# Extracted from: diff --git a/resnet_preprocessing.py b/resnet_preprocessing.py
# (new file mode 100644, index 0000000..72c799a, hunk @@ -0,0 +1,87 @@)
#
# Trimmed by Vasil Zlatanov
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing for ResNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def resize_or_crop_image(image, target_height, target_width):
    """Crops or bilinearly resizes `image` to the requested size.

    When the image is taller than `target_height` and its width already equals
    `target_width`, the top `target_height` rows are kept (crop anchored at the
    top-left corner) and the result is cast to float32. In every other case the
    image is resized to `[target_height, target_width]` with
    `tf.image.resize_images(..., align_corners=True)`.

    Args:
      image: `Tensor` whose first two dimensions are read as height and width
        (presumably `[height, width, channels]` -- confirm against callers).
      target_height: Desired output height.
      target_width: Desired output width.

    Returns:
      The cropped or resized image `Tensor`.
    """
    # Query the dynamic shape once and reuse it for both dimensions.
    image_shape = tf.shape(image)
    image_height = image_shape[0]
    image_width = image_shape[1]

    def _crop_top():
        # Keep the top `target_height` rows; cast so that both `tf.cond`
        # branches hand back a float image.
        cropped = tf.image.crop_to_bounding_box(
            image, 0, 0, target_height, target_width)
        return tf.cast(cropped, dtype=tf.float32)

    def _resize():
        return tf.image.resize_images(
            image, [target_height, target_width], align_corners=True)

    # If the viewport is long but the width is right, simply crop the length
    # of the page. Otherwise we just resize the image bilinearly.
    should_crop = tf.logical_and(
        tf.greater(image_height, target_height),
        tf.equal(target_width, image_width))
    return tf.cond(should_crop, _crop_top, _resize)


def _decode_and_size(image_bytes, target_width, target_height, resize):
    """Decodes a PNG to 3 channels, then resizes/crops it or casts to float32."""
    image = tf.image.decode_png(image_bytes, channels=3)
    if resize:
        return resize_or_crop_image(image, target_height, target_width)
    return tf.cast(image, tf.float32)


def preprocess_for_train(image_bytes, target_width, target_height, resize,
                         use_bfloat16):
    """Preprocesses the given image for training.

    In this trimmed module the training path performs the same steps as
    `preprocess_for_eval`: decode the PNG, then resize/crop or cast.

    Args:
      image_bytes: `Tensor` representing an image binary of arbitrary size.
        It is decoded with `tf.image.decode_png` using 3 channels.
      target_width: Desired output width, used only when `resize` is truthy.
      target_height: Desired output height, used only when `resize` is truthy.
      resize: If truthy, pass the decoded image through
        `resize_or_crop_image`; otherwise only cast it to float32.
      use_bfloat16: `bool` for whether to use bfloat16. Accepted for interface
        compatibility; this implementation does not read it.

    Returns:
      A preprocessed image `Tensor`.
    """
    return _decode_and_size(image_bytes, target_width, target_height, resize)


def preprocess_for_eval(image_bytes, target_width, target_height, resize,
                        use_bfloat16):
    """Preprocesses the given image for evaluation.

    Args:
      image_bytes: `Tensor` representing an image binary of arbitrary size.
        It is decoded with `tf.image.decode_png` using 3 channels.
      target_width: Desired output width, used only when `resize` is truthy.
      target_height: Desired output height, used only when `resize` is truthy.
      resize: If truthy, pass the decoded image through
        `resize_or_crop_image`; otherwise only cast it to float32.
      use_bfloat16: `bool` for whether to use bfloat16. Accepted for interface
        compatibility; this implementation does not read it.

    Returns:
      A preprocessed image `Tensor`.
    """
    return _decode_and_size(image_bytes, target_width, target_height, resize)


def preprocess_image(image_bytes, width, height, resize, is_training=False,
                     use_bfloat16=False):
    """Preprocesses the given image.

    Args:
      image_bytes: `Tensor` representing an image binary of arbitrary size.
      width: Target width forwarded to the train/eval preprocessing function.
      height: Target height forwarded to the train/eval preprocessing function.
      resize: Whether the decoded image should be resized/cropped to
        `height` x `width`.
      is_training: `bool` for whether the preprocessing is for training.
      use_bfloat16: `bool` for whether to use bfloat16. Forwarded unchanged;
        the train/eval functions in this module do not read it.

    Returns:
      A preprocessed image `Tensor`.
    """
    preprocess_fn = preprocess_for_train if is_training else preprocess_for_eval
    return preprocess_fn(image_bytes, width, height, resize, use_bfloat16)