path: root/util/make-tfrecords.py
blob: f9f3e71ffac3384e09b9a9f736817082f99e0e8f
#!/usr/bin/python
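"""Convert a directory of class-labelled PNG images into TFRecord shards.

Example invocation (the dataset path below is illustrative):

    python make-tfrecords.py --data=/path/to/dataset --validation_size=0.25 \
        --num_shards=2 --overwrite

--data must point at a directory with one subdirectory per class; the
subdirectory names become the class labels.
"""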
import random
import os
import sys
import math
import tensorflow as tf
import dataset_utils
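# dataset_utils is assumed to be the TF-slim datasets helper module
# (it provides image_to_tfexample and write_label_file, both used below).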
import numpy as np

#===============DEFINE YOUR ARGUMENTS==============
flags = tf.app.flags

#State your dataset directory
flags.DEFINE_string('data', None, 'String: Your dataset directory')

# The proportion of examples held out for validation. This is essentially your evaluation dataset.
flags.DEFINE_float('validation_size', 0.25, 'Float: The proportion of examples in the dataset to be used for validation')

# The number of shards per dataset split.
flags.DEFINE_integer('num_shards', 1, 'Int: Number of shards to split the TFRecord files')

# Seed for repeatability.
flags.DEFINE_integer('random_seed', 0, 'Int: Random seed to use for repeatability.')
flags.DEFINE_bool('overwrite', False, 'Overwrite previously generated files')

FLAGS = flags.FLAGS

class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes PNG data.
    self._decode_png_data = tf.placeholder(dtype=tf.string)
    self._decode_png = tf.image.decode_png(self._decode_png_data, channels=0)

  def read_image_dims(self, sess, image_data):
    image = self.decode_png(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_png(self, sess, image_data):
    image = sess.run(self._decode_png,
                     feed_dict={self._decode_png_data: image_data})
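    # With channels=0 the decoder keeps the file's own channel count; a
    # decoded still image is a rank-3 array (height, width, channels).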
    assert len(image.shape) == 3
    return image

def _get_filenames_and_classes(data):
  """Returns a list of filenames and inferred class names.

  Args:
    data: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths (each under `data`) and the sorted list of
    subdirectory names, which represent the class names.
  """
  directories = []
  class_names = []
  for filename in os.listdir(data):
    path = os.path.join(data, filename)
    if os.path.isdir(path):
      print(path)
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(data, split_name, shard_id, _NUM_SHARDS):
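  # Produces names like 'websites_train_00000-of-00002.tfrecord'
  # (shard 0 of 2 for the 'train' split).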
  output_filename = 'websites_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(data, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, data, _NUM_SHARDS):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    data: The directory where the converted datasets are stored.
  """
  assert split_name in ['train', 'validation']

  failed = 0
  success = 0
  # class_cnts is used for balancing training through class_weights
  class_cnts = [0] * len(class_names_to_ids)
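  # Ceiling division: every file lands in some shard, and the last shard may
  # hold fewer than num_per_shard files.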
  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:

      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            data, split_name, shard_id, _NUM_SHARDS)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            # sys.stdout.write('\r>> Converting image %d/%d shard %d: %s' % (
            #     i+1, len(filenames), shard_id, filenames[i]))
            # sys.stdout.flush()

            # Read the filename:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            try:
                height, width = image_reader.read_image_dims(sess, image_data)
                class_name = os.path.basename(os.path.dirname(filenames[i]))
                class_id = class_names_to_ids[class_name]

                example = dataset_utils.image_to_tfexample(
                    image_data, b'png', height, width, class_id)
                tfrecord_writer.write(example.SerializeToString())
                success += 1
                class_cnts[class_id] += 1
            except Exception:
                # Count and skip files that fail to decode or label.
                failed += 1

  sys.stdout.write('%d in total failed!\n' % failed)
  sys.stdout.write('%d in total were written successfully!\n' % success)
  sys.stdout.flush()
  return class_cnts


def _dataset_exists(data, _NUM_SHARDS):
  for split_name in ['train', 'validation']:
    for shard_id in range(_NUM_SHARDS):
      output_filename = _get_dataset_filename(
          data, split_name, shard_id, _NUM_SHARDS)
      if not tf.gfile.Exists(output_filename):
        return False
  return True

def main(_):

    #=============CHECKS==============
    # Check that a dataset directory was supplied
    if not FLAGS.data:
        raise ValueError('data is empty. Please supply the --data argument.')

    # If the TFRecord files already exist in the directory, exit without re-creating them
    if not FLAGS.overwrite and _dataset_exists(data = FLAGS.data, _NUM_SHARDS = FLAGS.num_shards):
        print('Dataset files already exist. Exiting without re-creating them.')
        print('Use the --overwrite flag, or remove the existing files first.')
        return None
    #==========END OF CHECKS============

    # Get a list of image file paths and a sorted list of class names from parsing the subdirectories of FLAGS.data.
    photo_filenames, class_names = _get_filenames_and_classes(FLAGS.data)

    # Map each class name to a unique integer id for use as the prediction label.
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    #Find the number of validation examples we need
    num_validation = int(FLAGS.validation_size * len(photo_filenames))

    # Shuffle deterministically and divide the dataset into train and validation splits:
    random.seed(FLAGS.random_seed)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[num_validation:]
    validation_filenames = photo_filenames[:num_validation]

    # First, convert the training and validation sets.
    train_cnts = _convert_dataset('train', training_filenames, class_names_to_ids,
                     data = FLAGS.data, _NUM_SHARDS = FLAGS.num_shards)
    val_cnts = _convert_dataset('validation', validation_filenames, class_names_to_ids,
                     data = FLAGS.data, _NUM_SHARDS = FLAGS.num_shards)

    # Finally, write the labels file:
    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    dataset_utils.write_label_file(labels_to_class_names, FLAGS.data)

    total_train_cnt = sum(train_cnts)
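    # Inverse-frequency class weights (total / (count * num_classes)): classes
    # with fewer training examples get larger weights; the 1e-10 term guards
    # against division by zero for empty classes.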
    class_cnt = len(train_cnts)
    class_weights = [total_train_cnt / (train_cnts[i] * class_cnt + 1e-10)
                     for i in range(class_cnt)]

    data_info = os.path.join(FLAGS.data, 'dinfo.npz')
    np.savez(data_info, train_cnt=total_train_cnt,
                        val_cnt=sum(val_cnts),
                        class_weights=class_weights,
                        classes=class_names
                        )
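    # The archive can be read back with np.load(data_info); class_weights is
    # intended for balancing the training loss across classes.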

    print('\nFinished converting the dataset!')

if __name__ == "__main__":
    main()