diff options
Diffstat (limited to 'imagenet_input.py')
-rw-r--r-- | imagenet_input.py | 166 |
1 files changed, 166 insertions, 0 deletions
diff --git a/imagenet_input.py b/imagenet_input.py new file mode 100644 index 0000000..55729b2 --- /dev/null +++ b/imagenet_input.py @@ -0,0 +1,166 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Efficient ImageNet input pipeline using tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +import resnet_preprocessing + +class ImageNetInput(object): + """Generates ImageNet input_fn for training or evaluation. + + The training data is assumed to be in TFRecord format with keys as specified + in the dataset_parser below, sharded across 1024 files, named sequentially: + train-00000-of-01024 + train-00001-of-01024 + ... + train-01023-of-01024 + + The validation data is in the same format but sharded in 128 files. + + The format of the data required is created by the script at: + https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py + + Args: + is_training: `bool` for whether the input is for training. + data_dir: `str` for the directory of the training and validation data; + if 'null' (the literal string 'null', not None), then construct a null + pipeline, consisting of empty images. + use_bfloat16: If True, use bfloat16 precision; else use float32. + per_core_batch_size: The per-TPU-core batch size to use. + """ + + def __init__(self, + width, + height, + resize, + is_training, + data_dir, + use_bfloat16=False, + per_core_batch_size=128): + self.image_preprocessing_fn = resnet_preprocessing.preprocess_image + self.is_training = is_training + self.width = width + self.height = height + self.resize = resize + self.use_bfloat16 = use_bfloat16 + self.data_dir = data_dir + if self.data_dir == 'null' or self.data_dir == '': + self.data_dir = None + self.per_core_batch_size = per_core_batch_size + + def dataset_parser(self, value): + """Parse an ImageNet record from a serialized string Tensor.""" + keys_to_features = { + 'image/encoded': + tf.FixedLenFeature((), tf.string, ''), + 'image/format': + tf.FixedLenFeature((), tf.string, 'png'), + 'image/class/label': + tf.FixedLenFeature([], tf.int64, -1), + 'image/height': + tf.FixedLenFeature([], tf.int64, -2), + 'image/width': + tf.FixedLenFeature([], tf.int64, -3), + + } + + parsed = tf.parse_single_example(value, keys_to_features) + image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) + + image = self.image_preprocessing_fn( + image_bytes, + width=self.width, height=self.height, + resize=self.resize, + is_training=self.is_training, + use_bfloat16=self.use_bfloat16, + ) + + # Subtract one so that labels are in [0, 1000), and cast to float32 for + # Keras model. + label = tf.cast(tf.cast( + tf.reshape(parsed['image/class/label'], shape=[1]), dtype=tf.int32), # - 1, + dtype=tf.float32) + + return image, label + + def input_fn(self): + """Input function which provides a single batch for train or eval. + + Returns: + A `tf.data.Dataset` object. + """ + # Shuffle the filenames to ensure better randomization. + file_pattern = os.path.join( + self.data_dir, 'websites_train*' if self.is_training else 'websites_validation*') + dataset = tf.data.Dataset.list_files(file_pattern, shuffle=self.is_training) + + if self.is_training: + dataset = dataset.repeat() + + def fetch_dataset(filename): + buffer_size = 100 * 1024 * 1024 # 100 MiB per file + dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + # Read the data from disk in parallel + dataset = dataset.interleave(fetch_dataset, cycle_length=16) + + if self.is_training: + dataset = dataset.shuffle(1024) + + # Parse, pre-process, and batch the data in parallel + dataset = dataset.apply( + tf.data.experimental.map_and_batch( + self.dataset_parser, + batch_size=self.per_core_batch_size, + num_parallel_batches=2, + drop_remainder=True)) + + # Prefetch overlaps in-feed with training + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + return dataset + + # TODO(xiejw): Remove this generator when we have support for top_k + # evaluation. + def evaluation_generator(self, sess): + """Creates a generator for evaluation.""" + next_batch = self.input_fn().make_one_shot_iterator().get_next() + while True: + try: + yield sess.run(next_batch) + except tf.errors.OutOfRangeError: + return + + def input_fn_null(self): + """Input function which provides null (black) images.""" + dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input) + dataset = dataset.prefetch(self.per_core_batch_size) + + dataset = dataset.batch(self.per_core_batch_size, drop_remainder=True) + + dataset = dataset.prefetch(32) # Prefetch overlaps in-feed with training + tf.logging.info('Input dataset: %s', str(dataset)) + return dataset + + def _get_null_input(self, _): + null_image = tf.zeros([320, 240, 3], tf.float32) + return null_image, tf.constant(0, tf.float32) |