Diffstat (limited to 'util/make-tfrecords.py')
-rwxr-xr-x  util/make-tfrecords.py  203
1 file changed, 203 insertions(+), 0 deletions(-)
diff --git a/util/make-tfrecords.py b/util/make-tfrecords.py
new file mode 100755
index 0000000..f9f3e71
--- /dev/null
+++ b/util/make-tfrecords.py
@@ -0,0 +1,203 @@
+#!/usr/bin/python
+import random
+import os
+import sys
+import math
+import tensorflow as tf
+import dataset_utils
+import numpy as np
+
+#===============DEFINE YOUR ARGUMENTS==============
+flags = tf.app.flags
+
+# State your dataset directory.
+flags.DEFINE_string('data', None, 'String: Your dataset directory')
+
+# The fraction of the dataset to hold out for validation. This is essentially your evaluation dataset.
+flags.DEFINE_float('validation_size', 0.25, 'Float: The proportion of examples in the dataset to be used for validation')
+
+# The number of shards per dataset split.
+flags.DEFINE_integer('num_shards', 1, 'Int: Number of shards to split the TFRecord files')
+
+# Seed for repeatability.
+flags.DEFINE_integer('random_seed', 0, 'Int: Random seed to use for repeatability.')
+flags.DEFINE_bool('overwrite', False, 'Overwrite previously generated files')
+
+FLAGS = flags.FLAGS
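+
+# Example invocation (the dataset path is hypothetical):
+#   ./make-tfrecords.py --data=/path/to/dataset --validation_size=0.25 \
+#       --num_shards=2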
+
+class ImageReader(object):
+ """Helper class that provides TensorFlow image coding utilities."""
+
+ def __init__(self):
+    # Initializes the function that decodes PNG data.
+ self._decode_png_data = tf.placeholder(dtype=tf.string)
+ self._decode_png = tf.image.decode_png(self._decode_png_data, channels=0)
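+    # channels=0 keeps the channel count stored in the file itself; the
+    # placeholder/decode pair is built once here and reused for every image,
+    # so the graph does not grow as files are processed.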
+
+ def read_image_dims(self, sess, image_data):
+ image = self.decode_png(sess, image_data)
+ return image.shape[0], image.shape[1]
+
+ def decode_png(self, sess, image_data):
+ image = sess.run(self._decode_png,
+ feed_dict={self._decode_png_data: image_data})
+ assert len(image.shape) == 3
+ return image
+
+def _get_filenames_and_classes(data):
+ """Returns a list of filenames and inferred class names.
+
+ Args:
+ data: A directory containing a set of subdirectories representing
+ class names. Each subdirectory should contain PNG or JPG encoded images.
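+      For example (hypothetical layout):
+        data/
+          class_a/0001.png
+          class_b/0002.png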
+
+  Returns:
+    A list of image file paths (each joined with the `data` prefix) and the
+    sorted list of subdirectory names, which serve as the class names.
+ """
+ directories = []
+ class_names = []
+ for filename in os.listdir(data):
+ path = os.path.join(data, filename)
+ if os.path.isdir(path):
+ print(path)
+ directories.append(path)
+ class_names.append(filename)
+
+ photo_filenames = []
+ for directory in directories:
+ for filename in os.listdir(directory):
+ path = os.path.join(directory, filename)
+ photo_filenames.append(path)
+
+ return photo_filenames, sorted(class_names)
+
+
+def _get_dataset_filename(data, split_name, shard_id, _NUM_SHARDS):
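+  # Produces names such as 'websites_train_00000-of-00001.tfrecord'
+  # (split_name='train', shard_id=0, _NUM_SHARDS=1).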
+ output_filename = 'websites_%s_%05d-of-%05d.tfrecord' % (
+ split_name, shard_id, _NUM_SHARDS)
+ return os.path.join(data, output_filename)
+
+
+def _convert_dataset(split_name, filenames, class_names_to_ids, data, _NUM_SHARDS):
+ """Converts the given filenames to a TFRecord dataset.
+
+ Args:
+    split_name: The name of the dataset split, either 'train' or 'validation'.
+    filenames: A list of absolute paths to png or jpg images.
+    class_names_to_ids: A dictionary from class names (strings) to ids
+      (integers).
+    data: The directory where the converted datasets are stored.
+    _NUM_SHARDS: The number of TFRecord shards to write per split.
+
+  Returns:
+    A list with the number of successfully written examples per class id.
+  """
+ assert split_name in ['train', 'validation']
+
+ failed = 0
+ success = 0
+ # class_cnts is used for balancing training through class_weights
+ class_cnts = [0] * len(class_names_to_ids)
+ num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
+
+ with tf.Graph().as_default():
+ image_reader = ImageReader()
+
+ with tf.Session('') as sess:
+
+ for shard_id in range(_NUM_SHARDS):
+ output_filename = _get_dataset_filename(
+ data, split_name, shard_id, _NUM_SHARDS)
+
+ with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
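+          # Each shard receives a contiguous slice of the shuffled filename list.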
+ start_ndx = shard_id * num_per_shard
+ end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
+ for i in range(start_ndx, end_ndx):
+            # sys.stdout.write('\r>> Converting image %d/%d shard %d: %s' % (
+            #     i+1, len(filenames), shard_id, filenames[i]))
+            # sys.stdout.flush()
+
+ # Read the filename:
+ image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
+ try:
+ height, width = image_reader.read_image_dims(sess, image_data)
+ class_name = os.path.basename(os.path.dirname(filenames[i]))
+ class_id = class_names_to_ids[class_name]
+
+ example = dataset_utils.image_to_tfexample(
+ image_data, b'png', height, width, class_id)
+ tfrecord_writer.write(example.SerializeToString())
+              success += 1
+              class_cnts[class_id] += 1
+            except Exception:
+              # Count and skip files that fail to decode as PNG images.
+              failed += 1
+
+ sys.stdout.write('%d in total failed!\n' % failed)
+      sys.stdout.write('%d in total were written successfully!\n' % success)
+ sys.stdout.flush()
+ return class_cnts
+
+
+def _dataset_exists(data, _NUM_SHARDS):
+ for split_name in ['train', 'validation']:
+ for shard_id in range(_NUM_SHARDS):
+ output_filename = _get_dataset_filename(
+ data, split_name, shard_id, _NUM_SHARDS)
+ if not tf.gfile.Exists(output_filename):
+ return False
+ return True
+
+def main(_):
+
+ #=============CHECKS==============
+ #Check if there is a dataset directory entered
+ if not FLAGS.data:
+    raise ValueError('The --data flag is empty. Please supply a dataset directory.')
+
+ #If the TFRecord files already exist in the directory, then exit without creating the files again
+ if not FLAGS.overwrite and _dataset_exists(data = FLAGS.data, _NUM_SHARDS = FLAGS.num_shards):
+ print('Dataset files already exist. Exiting without re-creating them.')
+    print('Use the --overwrite flag, or remove the existing files first.')
+ return None
+ #==========END OF CHECKS============
+
+  # Get a list of image file paths (full paths under FLAGS.data) and a sorted
+  # list of class names parsed from the subdirectories.
+ photo_filenames, class_names = _get_filenames_and_classes(FLAGS.data)
+
+  # Map each class name to a unique integer id to be used as the label later.
+ class_names_to_ids = dict(zip(class_names, range(len(class_names))))
+
+ #Find the number of validation examples we need
+ num_validation = int(FLAGS.validation_size * len(photo_filenames))
+
+  # Shuffle, then divide the dataset into training and validation splits:
+ random.seed(FLAGS.random_seed)
+ random.shuffle(photo_filenames)
+ training_filenames = photo_filenames[num_validation:]
+ validation_filenames = photo_filenames[:num_validation]
+
+ # First, convert the training and validation sets.
+  train_cnts = _convert_dataset('train', training_filenames, class_names_to_ids,
+                                data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards)
+  val_cnts = _convert_dataset('validation', validation_filenames, class_names_to_ids,
+                              data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards)
+
+ # Finally, write the labels file:
+ labels_to_class_names = dict(zip(range(len(class_names)), class_names))
+ dataset_utils.write_label_file(labels_to_class_names, FLAGS.data)
+
+ total_train_cnt = sum(train_cnts)
+ class_cnt = len(train_cnts)
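+  # Inverse-frequency weights: weight_c = total / (count_c * num_classes); a
+  # class at exactly the average frequency gets a weight of 1.0, and the
+  # 1e-10 term guards against division by zero for empty classes.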
+  class_weights = [total_train_cnt / (train_cnts[i] * class_cnt + 1e-10)
+                   for i in range(class_cnt)]
+
+ data_info = os.path.join(FLAGS.data, 'dinfo.npz')
+ np.savez(data_info, train_cnt=total_train_cnt,
+ val_cnt=sum(val_cnts),
+ class_weights=class_weights,
+ classes=class_names
+ )
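+  # A training script could later recover these with, e.g.:
+  #   dinfo = np.load(os.path.join(data_dir, 'dinfo.npz'))
+  #   class_weights = dinfo['class_weights']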
+
+ print('\nFinished converting the dataset!')
+
+if __name__ == "__main__":
+  tf.app.run()
+