# Trimmed by Vasil Zlatanov
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image preprocessing (PNG decode + optional crop/resize), adapted from the
ImageNet preprocessing for ResNet."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def resize_or_crop_image(image, target_height, target_width):
    """Crop or resize `image` to `target_height` x `target_width`.

    If the image is taller than the target and already exactly
    `target_width` wide, the top `target_height` rows are kept (a plain
    crop, no resampling). In every other case the whole image is resized
    bilinearly to the target size with `align_corners=True`.

    Args:
      image: 3-D image `Tensor` indexed as [height, width, channels].
      target_height: Desired output height in pixels.
      target_width: Desired output width in pixels.

    Returns:
      A `float32` image `Tensor` of size `target_height` x `target_width`.
    """
    image_shape = tf.shape(image)
    image_height = image_shape[0]
    image_width = image_shape[1]

    # A tall image (e.g. a long page) whose width already matches is cropped
    # from the top instead of being squashed; anything else is resized.
    should_crop = tf.logical_and(
        tf.greater(image_height, target_height),
        tf.equal(target_width, image_width))

    def _crop():
        # tf.cond requires both branches to return the same dtype;
        # resize_images yields float32, so the crop is cast to match.
        cropped = tf.image.crop_to_bounding_box(
            image, 0, 0, target_height, target_width)
        return tf.cast(cropped, dtype=tf.float32)

    def _resize():
        return tf.image.resize_images(
            image, [target_height, target_width], align_corners=True)

    return tf.cond(should_crop, _crop, _resize)


def _decode_and_resize(image_bytes, target_width, target_height, resize):
    """Decode PNG bytes to a float32 RGB image, optionally cropped/resized.

    Shared implementation of the train and eval paths, which currently
    perform identical preprocessing.
    """
    image = tf.image.decode_png(image_bytes, channels=3)
    # NOTE: `resize` is tested as a Python value while the graph is built,
    # so it must be a Python bool, not a boolean Tensor.
    if resize:
        return resize_or_crop_image(image, target_height, target_width)
    return tf.cast(image, tf.float32)


def preprocess_for_train(image_bytes, target_width, target_height, resize,
                         use_bfloat16):
    """Preprocesses the given image for training.

    Args:
      image_bytes: `Tensor` holding a PNG-encoded image of arbitrary size.
      target_width: Output width used when `resize` is true.
      target_height: Output height used when `resize` is true.
      resize: `bool`; if true, crop/resize to the target size via
        `resize_or_crop_image`, otherwise keep the decoded size.
      use_bfloat16: `bool` kept for interface compatibility. NOTE: it is
        currently ignored; the result is always `float32`.

    Returns:
      A preprocessed `float32` image `Tensor` with 3 channels.
    """
    del use_bfloat16  # Unused; see docstring.
    return _decode_and_resize(image_bytes, target_width, target_height, resize)


def preprocess_for_eval(image_bytes, target_width, target_height, resize,
                        use_bfloat16):
    """Preprocesses the given image for evaluation.

    Args:
      image_bytes: `Tensor` holding a PNG-encoded image of arbitrary size.
      target_width: Output width used when `resize` is true.
      target_height: Output height used when `resize` is true.
      resize: `bool`; if true, crop/resize to the target size via
        `resize_or_crop_image`, otherwise keep the decoded size.
      use_bfloat16: `bool` kept for interface compatibility. NOTE: it is
        currently ignored; the result is always `float32`.

    Returns:
      A preprocessed `float32` image `Tensor` with 3 channels.
    """
    del use_bfloat16  # Unused; see docstring.
    return _decode_and_resize(image_bytes, target_width, target_height, resize)


def preprocess_image(image_bytes, width, height, resize, is_training=False,
                     use_bfloat16=False):
    """Preprocesses the given image.

    Args:
      image_bytes: `Tensor` holding a PNG-encoded image of arbitrary size.
      width: Target output width (used when `resize` is true).
      height: Target output height (used when `resize` is true).
      resize: `bool`; whether to crop/resize to `height` x `width`.
      is_training: `bool` for whether the preprocessing is for training.
        The train and eval paths currently behave identically.
      use_bfloat16: `bool` forwarded to the train/eval functions, which
        currently ignore it (output is always `float32`).

    Returns:
      A preprocessed `float32` image `Tensor` with 3 channels.
    """
    if is_training:
        return preprocess_for_train(image_bytes, width, height, resize,
                                    use_bfloat16)
    return preprocess_for_eval(image_bytes, width, height, resize,
                               use_bfloat16)