summaryrefslogtreecommitdiff
path: root/resnet_preprocessing.py
diff options
context:
space:
mode:
Diffstat (limited to 'resnet_preprocessing.py')
-rw-r--r--resnet_preprocessing.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/resnet_preprocessing.py b/resnet_preprocessing.py
new file mode 100644
index 0000000..72c799a
--- /dev/null
+++ b/resnet_preprocessing.py
@@ -0,0 +1,87 @@
+# Trimmed by Vasil Zlatanov
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ImageNet preprocessing for ResNet."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+def resize_or_crop_image(image, target_height, target_width):
+ image_height = tf.shape(image)[0]
+ image_width = tf.shape(image)[1]
+ # If the viewport is long but the width is right simply crop the length of the page
+ # Otherwise we just resize the image bilinearly
+ image = tf.cond(
+ tf.logical_and(tf.greater(image_height, target_height),tf.equal(target_width, image_width)),
+ lambda: tf.cast(tf.image.crop_to_bounding_box(image, 0, 0, target_height, target_width),dtype=tf.float32),
+ lambda: tf.image.resize_images(image, [target_height,target_width], align_corners=True)
+ )
+ return image
+
+def preprocess_for_train(image_bytes, target_width, target_height, resize, use_bfloat16):
+ """Preprocesses the given image for evaluation.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ use_bfloat16: `bool` for whether to use bfloat16.
+
+ Returns:
+ A preprocessed image `Tensor`.
+ """
+ image = tf.image.decode_png(image_bytes, channels=3)
+ if resize:
+ image = resize_or_crop_image(image, target_height, target_width)
+ else:
+ image = tf.cast(image, tf.float32)
+
+ return image
+
+
+def preprocess_for_eval(image_bytes, target_width, target_height, resize, use_bfloat16):
+ """Preprocesses the given image for evaluation.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ use_bfloat16: `bool` for whether to use bfloat16.
+
+ Returns:
+ A preprocessed image `Tensor`.
+ """
+ image = tf.image.decode_png(image_bytes, channels=3)
+ if resize:
+ image = resize_or_crop_image(image, target_height, target_width)
+ else:
+ image = tf.cast(image, tf.float32)
+
+ return image
+
+
+def preprocess_image(image_bytes, width, height, resize, is_training=False, use_bfloat16=False):
+ """Preprocesses the given image.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ is_training: `bool` for whether the preprocessing is for training.
+ use_bfloat16: `bool` for whether to use bfloat16.
+
+ Returns:
+ A preprocessed image `Tensor`.
+ """
+ if is_training:
+ return preprocess_for_train(image_bytes, width, height, resize, use_bfloat16)
+ else:
+ return preprocess_for_eval(image_bytes, width, height, resize, use_bfloat16)