diff options
Diffstat (limited to 'util/make-tfrecords.py')
-rwxr-xr-x | util/make-tfrecords.py | 203 |
1 files changed, 203 insertions, 0 deletions
#!/usr/bin/python
"""Build train/validation TFRecord shards from a directory of PNG images.

FLAGS.data must contain one subdirectory per class; each subdirectory
holds PNG-encoded images.  The script writes TFRecord shards for the
'train' and 'validation' splits, a labels file, and a `dinfo.npz`
archive holding example counts and per-class weights for loss balancing.
"""
import math
import os
import random
import sys

import numpy as np
import tensorflow as tf

import dataset_utils

# =============== DEFINE YOUR ARGUMENTS ==============
flags = tf.app.flags

# State your dataset directory (one subdirectory per class).
flags.DEFINE_string('data', None, 'String: Your dataset directory')

# Proportion of examples held out as the validation (evaluation) set.
flags.DEFINE_float('validation_size', 0.25,
                   'Float: The proportion of examples in the dataset to be used for validation')

# The number of shards per dataset split.
flags.DEFINE_integer('num_shards', 1, 'Int: Number of shards to split the TFRecord files')

# Seed for repeatability.
flags.DEFINE_integer('random_seed', 0, 'Int: Random seed to use for repeatability.')
flags.DEFINE_bool('overwrite', False, 'Overwrite previously generated files')

FLAGS = flags.FLAGS


class ImageReader(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Placeholder fed with raw PNG bytes; channels=0 lets the decoder
        # keep the channel count encoded in the file itself.
        self._decode_png_data = tf.placeholder(dtype=tf.string)
        self._decode_png = tf.image.decode_png(self._decode_png_data, channels=0)

    def read_image_dims(self, sess, image_data):
        """Return (height, width) of the PNG encoded in `image_data`."""
        image = self.decode_png(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_png(self, sess, image_data):
        """Decode PNG bytes into an HxWxC array; asserts the result is rank 3."""
        image = sess.run(self._decode_png,
                         feed_dict={self._decode_png_data: image_data})
        assert len(image.shape) == 3
        return image


def _get_filenames_and_classes(data):
    """Returns a list of filenames and inferred class names.

    Args:
        data: A directory containing a set of subdirectories representing
            class names. Each subdirectory should contain PNG or JPG
            encoded images.

    Returns:
        A list of image file paths, relative to `data`, and the sorted
        list of subdirectory names, representing class names.
    """
    directories = []
    class_names = []
    for filename in os.listdir(data):
        path = os.path.join(data, filename)
        if os.path.isdir(path):
            print(path)
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            photo_filenames.append(os.path.join(directory, filename))

    return photo_filenames, sorted(class_names)


def _get_dataset_filename(data, split_name, shard_id, _NUM_SHARDS):
    """Build the output path for one TFRecord shard of a split."""
    output_filename = 'websites_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    return os.path.join(data, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, data, _NUM_SHARDS):
    """Converts the given filenames to a TFRecord dataset.

    Args:
        split_name: The name of the dataset, either 'train' or 'validation'.
        filenames: A list of absolute paths to png or jpg images.
        class_names_to_ids: A dictionary from class names (strings) to ids
            (integers).
        data: The directory where the converted datasets are stored.
        _NUM_SHARDS: Number of TFRecord shards to write for this split.

    Returns:
        A list of per-class example counts (indexed by class id), used
        later for balancing training through class weights.
    """
    assert split_name in ['train', 'validation']

    failed = 0
    success = 0
    # class_cnts is used for balancing training through class_weights.
    class_cnts = [0] * len(class_names_to_ids)
    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    data, split_name, shard_id, _NUM_SHARDS)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        # BUGFIX: close the file handle promptly instead of
                        # leaking it until garbage collection.
                        with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:
                            image_data = image_file.read()
                        try:
                            height, width = image_reader.read_image_dims(sess, image_data)
                            # The class is the name of the parent directory.
                            class_name = os.path.basename(os.path.dirname(filenames[i]))
                            class_id = class_names_to_ids[class_name]

                            example = dataset_utils.image_to_tfexample(
                                image_data, b'png', height, width, class_id)
                            tfrecord_writer.write(example.SerializeToString())
                            success += 1
                            class_cnts[class_id] += 1
                        except Exception:
                            # Skip unreadable/corrupt images but keep a count.
                            # BUGFIX: a bare `except:` also swallowed
                            # KeyboardInterrupt/SystemExit.
                            failed += 1

    sys.stdout.write('%d in total failed!\n' % failed)
    sys.stdout.write('%d in total were written successfully!\n' % success)
    sys.stdout.flush()
    return class_cnts


def _dataset_exists(data, _NUM_SHARDS):
    """Return True iff every shard of both splits already exists in `data`."""
    for split_name in ['train', 'validation']:
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(
                data, split_name, shard_id, _NUM_SHARDS)
            if not tf.gfile.Exists(output_filename):
                return False
    return True


def main(_=None):
    """Entry point: validate flags, split the dataset and write TFRecords."""
    # ============= CHECKS ==============
    # Check if there is a dataset directory entered.
    if not FLAGS.data:
        raise ValueError('data is empty. Please state a data argument.')

    # If the TFRecord files already exist, exit without re-creating them
    # unless --overwrite was given.
    if not FLAGS.overwrite and _dataset_exists(data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards):
        print('Dataset files already exist. Exiting without re-creating them.')
        print('Use --overwrite flag or remove them')
        return None
    # ========== END OF CHECKS ============

    # Image paths and sorted class names parsed from the subdirectories.
    photo_filenames, class_names = _get_filenames_and_classes(FLAGS.data)

    # Refer each class name to a specific integer id for predictions later.
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    # Number of validation examples to hold out.
    num_validation = int(FLAGS.validation_size * len(photo_filenames))

    # Shuffle deterministically, then split into train/validation.
    random.seed(FLAGS.random_seed)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[num_validation:]
    validation_filenames = photo_filenames[:num_validation]

    # Convert the training and validation sets.
    # BUGFIX: the shard count was hard-coded to 1 here while
    # _dataset_exists() checked FLAGS.num_shards, so --num_shards > 1
    # produced an inconsistent dataset.
    train_cnts = _convert_dataset('train', training_filenames, class_names_to_ids,
                                  data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards)
    val_cnts = _convert_dataset('validation', validation_filenames, class_names_to_ids,
                                data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards)

    # Finally, write the labels file (id -> class name).
    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    dataset_utils.write_label_file(labels_to_class_names, FLAGS.data)

    # Inverse-frequency class weights; 1e-10 guards against empty classes.
    total_train_cnt = sum(train_cnts)
    class_cnt = len(train_cnts)
    class_weights = [total_train_cnt / (train_cnts[i] * class_cnt + 1e-10)
                     for i in range(class_cnt)]

    data_info = os.path.join(FLAGS.data, 'dinfo.npz')
    np.savez(data_info,
             train_cnt=total_train_cnt,
             val_cnt=sum(val_cnts),
             class_weights=class_weights,
             classes=class_names)

    print('\nFinished converting the dataset!')


if __name__ == "__main__":
    # BUGFIX: tf.app.run() parses FLAGS before dispatching to main();
    # calling main() directly can leave flags unparsed on some TF versions.
    tf.app.run()