diff options
Diffstat (limited to 'logo_input.py')
-rw-r--r-- | logo_input.py | 143 |
1 file changed, 143 insertions, 0 deletions
diff --git a/logo_input.py b/logo_input.py new file mode 100644 index 0000000..1e017a4 --- /dev/null +++ b/logo_input.py @@ -0,0 +1,143 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Efficient ImageNet input pipeline using tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +import resnet_preprocessing + +class ImageNetInput(object): + def __init__(self, + width, + height, + resize, + is_training, + data_dir, + use_bfloat16=False, + per_core_batch_size=128): + self.image_preprocessing_fn = resnet_preprocessing.preprocess_image + self.is_training = is_training + self.width = width + self.height = height + self.resize = resize + self.use_bfloat16 = use_bfloat16 + self.data_dir = data_dir + if self.data_dir == 'null' or self.data_dir == '': + self.data_dir = None + self.per_core_batch_size = per_core_batch_size + + def dataset_parser(self, value): + """Parse an ImageNet record from a serialized string Tensor.""" + keys_to_features = { + 'image/encoded': + tf.FixedLenFeature((), tf.string, ''), + 'image/format': + tf.FixedLenFeature((), tf.string, 'png'), + 'image/class/label': + tf.FixedLenFeature([], tf.int64, -1), + 'image/height': + tf.FixedLenFeature([], tf.int64, -2), + 'image/width': + 
tf.FixedLenFeature([], tf.int64, -3), + + } + + parsed = tf.parse_single_example(value, keys_to_features) + image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) + + image = self.image_preprocessing_fn( + image_bytes, + width=self.width, height=self.height, + resize=self.resize, + is_training=self.is_training, + use_bfloat16=self.use_bfloat16, + ) + + # Subtract one so that labels are in [0, 1000), and cast to float32 for + # Keras model. + label = tf.cast(tf.cast( + tf.reshape(parsed['image/class/label'], shape=[1]), dtype=tf.int32), # - 1, + dtype=tf.float32) + + return image, label + + def input_fn(self): + """Input function which provides a single batch for train or eval. + + Returns: + A `tf.data.Dataset` object. + """ + # Shuffle the filenames to ensure better randomization. + file_pattern = os.path.join( + self.data_dir, 'websites_train*' if self.is_training else 'websites_validation*') + dataset = tf.data.Dataset.list_files(file_pattern, shuffle=self.is_training) + + if self.is_training: + dataset = dataset.repeat() + + def fetch_dataset(filename): + buffer_size = 100 * 1024 * 1024 # 100 MiB per file + dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + # Read the data from disk in parallel + dataset = dataset.interleave(fetch_dataset, cycle_length=16) + + if self.is_training: + dataset = dataset.shuffle(1024) + + # Parse, pre-process, and batch the data in parallel + dataset = dataset.apply( + tf.data.experimental.map_and_batch( + self.dataset_parser, + batch_size=self.per_core_batch_size, + num_parallel_batches=2, + drop_remainder=True)) + + # Prefetch overlaps in-feed with training + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + return dataset + + # TODO(xiejw): Remove this generator when we have support for top_k + # evaluation. 
+ def evaluation_generator(self, sess): + """Creates a generator for evaluation.""" + next_batch = self.input_fn().make_one_shot_iterator().get_next() + while True: + try: + yield sess.run(next_batch) + except tf.errors.OutOfRangeError: + return + + def input_fn_null(self): + """Input function which provides null (black) images.""" + dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input) + dataset = dataset.prefetch(self.per_core_batch_size) + + dataset = dataset.batch(self.per_core_batch_size, drop_remainder=True) + + dataset = dataset.prefetch(32) # Prefetch overlaps in-feed with training + tf.logging.info('Input dataset: %s', str(dataset)) + return dataset + + def _get_null_input(self, _): + null_image = tf.zeros([320, 240, 3], tf.float32) + return null_image, tf.constant(0, tf.float32) |