# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Efficient ImageNet input pipeline using tf.data.Dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

import resnet_preprocessing


class ImageNetInput(object):
  """Generates ImageNet input_fn for training or evaluation.

  The training data is assumed to be in TFRecord format with keys as specified
  in the dataset_parser below, sharded across 1024 files, named sequentially:

      train-00000-of-01024
      train-00001-of-01024
      ...
      train-01023-of-01024

  The validation data is in the same format, sharded across 128 files.

  The required data format is created by the script at:
  https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py

  Args:
    width: `int` for the output image width, forwarded to the preprocessing
        function.
    height: `int` for the output image height, forwarded to the preprocessing
        function.
    resize: resize policy, forwarded to
        `resnet_preprocessing.preprocess_image`.
    is_training: `bool` for whether the input is for training.
    data_dir: `str` for the directory of the training and validation data;
        if 'null' (the literal string 'null', not None), then construct a null
        pipeline, consisting of empty images.
    use_bfloat16: If True, use bfloat16 precision; else use float32.
    per_core_batch_size: The per-TPU-core batch size to use.
  """

  def __init__(self,
               width,
               height,
               resize,
               is_training,
               data_dir,
               use_bfloat16=False,
               per_core_batch_size=128):
    self.image_preprocessing_fn = resnet_preprocessing.preprocess_image
    self.is_training = is_training
    self.width = width
    self.height = height
    self.resize = resize
    self.use_bfloat16 = use_bfloat16
    self.data_dir = data_dir
    if self.data_dir == 'null' or self.data_dir == '':
      self.data_dir = None
    self.per_core_batch_size = per_core_batch_size

  def dataset_parser(self, value):
    """Parses an ImageNet record from a serialized string Tensor."""
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, ''),
        'image/format': tf.FixedLenFeature((), tf.string, 'png'),
        'image/class/label': tf.FixedLenFeature([], tf.int64, -1),
        'image/height': tf.FixedLenFeature([], tf.int64, -2),
        'image/width': tf.FixedLenFeature([], tf.int64, -3),
    }

    parsed = tf.parse_single_example(value, keys_to_features)
    image_bytes = tf.reshape(parsed['image/encoded'], shape=[])

    image = self.image_preprocessing_fn(
        image_bytes,
        width=self.width,
        height=self.height,
        resize=self.resize,
        is_training=self.is_training,
        use_bfloat16=self.use_bfloat16,
    )

    # Cast the label to float32 for the Keras model. Note: the `- 1` shift
    # that would map labels into [0, 1000) is commented out below, so labels
    # pass through unshifted.
    label = tf.cast(
        tf.cast(
            tf.reshape(parsed['image/class/label'], shape=[1]),
            dtype=tf.int32),  # - 1,
        dtype=tf.float32)

    return image, label
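  # A round-trip sketch for exercising dataset_parser (not part of the
  # original pipeline; `png_bytes` and `label_id` are hypothetical
  # placeholders for an encoded image string and an integer class id):
  #
  #   example = tf.train.Example(features=tf.train.Features(feature={
  #       'image/encoded': tf.train.Feature(
  #           bytes_list=tf.train.BytesList(value=[png_bytes])),
  #       'image/format': tf.train.Feature(
  #           bytes_list=tf.train.BytesList(value=[b'png'])),
  #       'image/class/label': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[label_id])),
  #   }))
  #   image, label = input_pipeline.dataset_parser(example.SerializeToString())
  #
  # The height/width keys fall back to their FixedLenFeature defaults when
  # absent, so a minimal test record can omit them.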
  def input_fn(self):
    """Input function which provides a single batch for train or eval.

    Returns:
      A `tf.data.Dataset` object.
    """
    # Shuffle the filenames to ensure better randomization.
    file_pattern = os.path.join(
        self.data_dir,
        'websites_train*' if self.is_training else 'websites_validation*')

    dataset = tf.data.Dataset.list_files(file_pattern,
                                         shuffle=self.is_training)

    if self.is_training:
      dataset = dataset.repeat()

    def fetch_dataset(filename):
      buffer_size = 100 * 1024 * 1024  # 100 MiB per file
      dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size)
      return dataset

    # Read the data from disk in parallel.
    dataset = dataset.interleave(fetch_dataset, cycle_length=16)

    if self.is_training:
      dataset = dataset.shuffle(1024)

    # Parse, pre-process, and batch the data in parallel.
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(
            self.dataset_parser,
            batch_size=self.per_core_batch_size,
            num_parallel_batches=2,
            drop_remainder=True))

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

  # TODO(xiejw): Remove this generator when we have support for top_k
  # evaluation.
  def evaluation_generator(self, sess):
    """Creates a generator for evaluation."""
    next_batch = self.input_fn().make_one_shot_iterator().get_next()
    while True:
      try:
        yield sess.run(next_batch)
      except tf.errors.OutOfRangeError:
        return

  def input_fn_null(self):
    """Input function which provides null (black) images."""
    dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input)
    dataset = dataset.prefetch(self.per_core_batch_size)

    dataset = dataset.batch(self.per_core_batch_size, drop_remainder=True)
    dataset = dataset.prefetch(32)  # Prefetch overlaps in-feed with training.
    tf.logging.info('Input dataset: %s', str(dataset))
    return dataset

  def _get_null_input(self, _):
    null_image = tf.zeros([320, 240, 3], tf.float32)
    return null_image, tf.constant(0, tf.float32)
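
# A minimal usage sketch, not part of this module. It assumes TensorFlow 1.x
# graph mode, TFRecord shards matching 'websites_train*' under a hypothetical
# `DATA_DIR`, and example values for width/height; `resize=True` is a guess at
# that flag's form, since its semantics are defined by
# resnet_preprocessing.preprocess_image:
#
#   imagenet_train = ImageNetInput(
#       width=224, height=224, resize=True,
#       is_training=True, data_dir=DATA_DIR,
#       use_bfloat16=False, per_core_batch_size=128)
#   images, labels = (
#       imagenet_train.input_fn().make_one_shot_iterator().get_next())
#   with tf.Session() as sess:
#     image_batch, label_batch = sess.run([images, labels])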