#!/usr/bin/python
"""Convert a directory of class-labelled images into sharded TFRecord files.

Expected layout of ``--data``::

    data/
        class_a/<images>
        class_b/<images>
        ...

Each immediate subdirectory of ``--data`` is treated as one class. The script
writes ``websites_{train,validation}_XXXXX-of-YYYYY.tfrecord`` shards, a labels
file (via ``dataset_utils.write_label_file``) and ``dinfo.npz`` (per-split
example counts, per-class training weights and class names) into ``--data``.
"""
import random
import os
import sys
import math

import tensorflow as tf
import dataset_utils
import numpy as np

#===============DEFINE YOUR ARGUMENTS==============
flags = tf.app.flags

# Dataset directory: one subdirectory per class. Output files are written here too.
flags.DEFINE_string('data', None, 'String: Your dataset directory')

# Fraction (not a count) of all examples held out for validation; must be in [0, 1].
flags.DEFINE_float('validation_size', 0.25, 'Float: The proportion of examples in the dataset to be used for validation')

# The number of shards per dataset split; must be >= 1.
flags.DEFINE_integer('num_shards', 1, 'Int: Number of shards to split the TFRecord files')

# Seed for the train/validation shuffle, for repeatability.
flags.DEFINE_integer('random_seed', 0, 'Int: Random seed to use for repeatability.')

flags.DEFINE_bool('overwrite', False, 'Overwrite prevoiusly generated files')

FLAGS = flags.FLAGS

# Per-image errors that mean "this file is not a usable image": the image is
# skipped and counted as failed instead of aborting the whole conversion.
# tf.errors.OpError covers read and decode failures raised by TensorFlow,
# AssertionError comes from ImageReader.decode_png's rank check, and KeyError
# from an unknown class directory.
_SKIPPABLE_ERRORS = (tf.errors.OpError, AssertionError, KeyError, IOError, OSError)


class ImageReader(object):
    """Helper class that provides TensorFlow image decoding utilities.

    One placeholder/decode op pair is built in the current default graph at
    construction time and reused for every image via ``sess.run``. This means
    decoding many images does not keep adding nodes to the graph.
    """

    def __init__(self):
        # Encoded image bytes are fed through this placeholder.
        self._decode_png_data = tf.placeholder(dtype=tf.string)
        # channels=0 keeps whatever number of channels the encoded image has.
        # NOTE(review): per the TensorFlow docs, decode_png also accepts JPEG
        # and non-animated GIF input, which is why JPG files decode here too.
        self._decode_png = tf.image.decode_png(self._decode_png_data, channels=0)

    def read_image_dims(self, sess, image_data):
        """Return ``(height, width)`` of the encoded image ``image_data``."""
        image = self.decode_png(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_png(self, sess, image_data):
        """Decode ``image_data`` in ``sess`` and return the pixel array.

        Raises:
          tf.errors.OpError: if TensorFlow cannot decode the bytes.
          AssertionError: if the decoded array is not 3-D.
        """
        image = sess.run(self._decode_png,
                         feed_dict={self._decode_png_data: image_data})
        assert len(image.shape) == 3
        return image


def _get_filenames_and_classes(data):
    """Returns a list of filenames and inferred class names.

    Args:
      data: A directory containing a set of subdirectories representing class
        names. Each subdirectory should contain PNG or JPG encoded images.

    Returns:
      A sorted list of image file paths and the sorted list of subdirectory
      names, which represent the class names. Each path is ``data`` joined with
      the class subdirectory and the file name. Only regular files are
      returned; nested directories are ignored.
    """
    directories = []
    class_names = []
    for filename in os.listdir(data):
        path = os.path.join(data, filename)
        if os.path.isdir(path):
            print(path)
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            # A nested directory is not an image and cannot be opened as one.
            if os.path.isfile(path):
                photo_filenames.append(path)

    # os.listdir order is arbitrary and filesystem dependent. Sorting here
    # makes the seeded shuffle in main() produce the same split everywhere.
    return sorted(photo_filenames), sorted(class_names)


def _get_dataset_filename(data, split_name, shard_id, _NUM_SHARDS):
    """Return the path of shard ``shard_id`` (of ``_NUM_SHARDS``) for ``split_name`` inside ``data``."""
    output_filename = 'websites_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    return os.path.join(data, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, data, _NUM_SHARDS):
    """Converts the given filenames to a TFRecord dataset.

    Files that cannot be read or decoded are skipped and reported on stderr.
    They are counted as failures rather than aborting the conversion. Errors
    while writing the TFRecord output are not swallowed.

    Args:
      split_name: The name of the dataset, either 'train' or 'validation'.
      filenames: A list of paths to png or jpg images.
      class_names_to_ids: A dictionary from class names (strings) to ids (integers).
      data: The directory where the converted datasets are stored.
      _NUM_SHARDS: Number of TFRecord shards to split this split into (>= 1).

    Returns:
      A list indexed by class id with the number of examples of that class
      that were written successfully.
    """
    assert split_name in ['train', 'validation']

    failed = 0
    success = 0
    # class_cnts is used for balancing training through class_weights
    class_cnts = [0] * len(class_names_to_ids)

    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    data, split_name, shard_id, _NUM_SHARDS)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for filename in filenames[start_ndx:end_ndx]:
                        try:
                            # Context manager so the handle is closed per image.
                            with tf.gfile.GFile(filename, 'rb') as image_file:
                                image_data = image_file.read()
                            height, width = image_reader.read_image_dims(sess, image_data)
                            class_name = os.path.basename(os.path.dirname(filename))
                            class_id = class_names_to_ids[class_name]
                        except _SKIPPABLE_ERRORS as err:
                            failed += 1
                            sys.stderr.write('Skipping %s: %s\n' % (filename, err))
                            continue

                        # NOTE(review): the format field is always b'png', even
                        # for JPG input; confirm downstream readers accept this.
                        example = dataset_utils.image_to_tfexample(
                            image_data, b'png', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())
                        success += 1
                        class_cnts[class_id] += 1

    sys.stdout.write('%d in total failed!\n' % failed)
    sys.stdout.write('%d in total were written successfuly!\n' % success)
    sys.stdout.flush()
    return class_cnts


def _dataset_exists(data, _NUM_SHARDS):
    """Return True only if every train and validation shard file already exists in ``data``."""
    for split_name in ['train', 'validation']:
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(
                data, split_name, shard_id, _NUM_SHARDS)
            if not tf.gfile.Exists(output_filename):
                return False
    return True


def main():
    """Build the TFRecord shards, labels file and dinfo.npz described by FLAGS.

    Raises:
      ValueError: if --data is missing, --num_shards < 1, or
        --validation_size is outside [0, 1].
    """
    #=============CHECKS==============
    # Check that a dataset directory was given.
    if not FLAGS.data:
        raise ValueError('data is empty. Please state a data argument.')
    if FLAGS.num_shards < 1:
        raise ValueError('num_shards must be at least 1, got %d.' % FLAGS.num_shards)
    if not 0.0 <= FLAGS.validation_size <= 1.0:
        raise ValueError('validation_size must be in [0, 1], got %r.' % FLAGS.validation_size)

    # If the TFRecord files already exist in the directory, exit without creating them again.
    if not FLAGS.overwrite and _dataset_exists(data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards):
        print('Dataset files already exist. Exiting without re-creating them.')
        print('Use --overwrite flag or remove them')
        return None
    #==========END OF CHECKS============

    # Get the sorted list of image paths and the sorted class names (subdirectories).
    photo_filenames, class_names = _get_filenames_and_classes(FLAGS.data)

    # Map each class name to an integer id, in sorted class-name order.
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    # Number of validation examples derived from the validation fraction.
    num_validation = int(FLAGS.validation_size * len(photo_filenames))

    # Divide the shuffled examples into train and validation splits.
    random.seed(FLAGS.random_seed)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[num_validation:]
    validation_filenames = photo_filenames[:num_validation]

    # Use the same shard count that _dataset_exists checks above.
    train_cnts = _convert_dataset('train', training_filenames, class_names_to_ids,
                                  data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards)
    val_cnts = _convert_dataset('validation', validation_filenames, class_names_to_ids,
                                data=FLAGS.data, _NUM_SHARDS=FLAGS.num_shards)

    # Finally, write the labels file.
    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    dataset_utils.write_label_file(labels_to_class_names, FLAGS.data)

    total_train_cnt = sum(train_cnts)
    class_cnt = len(train_cnts)
    # Inverse-frequency weights: total / (count_i * num_classes), so a
    # perfectly balanced training set gives 1.0 per class. The 1e-10 term
    # avoids division by zero for a class with no training examples, which
    # then gets a very large weight.
    class_weights = [total_train_cnt / (train_cnts[i] * class_cnt + 1e-10)
                     for i in range(class_cnt)]

    data_info = os.path.join(FLAGS.data, 'dinfo.npz')
    np.savez(data_info,
             train_cnt=total_train_cnt,
             val_cnt=sum(val_cnts),
             class_weights=class_weights,
             classes=class_names)

    print('\nFinished converting the dataset!')


if __name__ == "__main__":
    main()