summaryrefslogtreecommitdiff
path: root/util/dataset_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/dataset_utils.py')
-rw-r--r--util/dataset_utils.py150
1 files changed, 150 insertions, 0 deletions
diff --git a/util/dataset_utils.py b/util/dataset_utils.py
new file mode 100644
index 0000000..fdaefca
--- /dev/null
+++ b/util/dataset_utils.py
@@ -0,0 +1,150 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains utilities for downloading and converting datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import tarfile
+
+from six.moves import urllib
+import tensorflow as tf
+
+LABELS_FILENAME = 'labels.txt'
+
+
+def int64_feature(values):
+ """Returns a TF-Feature of int64s.
+
+ Args:
+ values: A scalar or list of values.
+
+ Returns:
+ A TF-Feature.
+ """
+ if not isinstance(values, (tuple, list)):
+ values = [values]
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
+
+
+def bytes_feature(values):
+ """Returns a TF-Feature of bytes.
+
+ Args:
+ values: A string.
+
+ Returns:
+ A TF-Feature.
+ """
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
+
+
+def float_feature(values):
+ """Returns a TF-Feature of floats.
+
+ Args:
+ values: A scalar of list of values.
+
+ Returns:
+ A TF-Feature.
+ """
+ if not isinstance(values, (tuple, list)):
+ values = [values]
+ return tf.train.Feature(float_list=tf.train.FloatList(value=values))
+
+
+def image_to_tfexample(image_data, image_format, height, width, class_id):
+ return tf.train.Example(features=tf.train.Features(feature={
+ 'image/encoded': bytes_feature(image_data),
+ 'image/format': bytes_feature(image_format),
+ 'image/class/label': int64_feature(class_id),
+ 'image/height': int64_feature(height),
+ 'image/width': int64_feature(width),
+ }))
+
+
+def download_and_uncompress_tarball(tarball_url, dataset_dir):
+ """Downloads the `tarball_url` and uncompresses it locally.
+
+ Args:
+ tarball_url: The URL of a tarball file.
+ dataset_dir: The directory where the temporary files are stored.
+ """
+ filename = tarball_url.split('/')[-1]
+ filepath = os.path.join(dataset_dir, filename)
+
+ def _progress(count, block_size, total_size):
+ sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+ filename, float(count * block_size) / float(total_size) * 100.0))
+ sys.stdout.flush()
+ filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
+ print()
+ statinfo = os.stat(filepath)
+ print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+ tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
+
+
+def write_label_file(labels_to_class_names, dataset_dir,
+ filename=LABELS_FILENAME):
+ """Writes a file with the list of class names.
+
+ Args:
+ labels_to_class_names: A map of (integer) labels to class names.
+ dataset_dir: The directory in which the labels file should be written.
+ filename: The filename where the class names are written.
+ """
+ labels_filename = os.path.join(dataset_dir, filename)
+ with tf.gfile.Open(labels_filename, 'w') as f:
+ for label in labels_to_class_names:
+ class_name = labels_to_class_names[label]
+ f.write('%d:%s\n' % (label, class_name))
+
+
+def has_labels(dataset_dir, filename=LABELS_FILENAME):
+ """Specifies whether or not the dataset directory contains a label map file.
+
+ Args:
+ dataset_dir: The directory in which the labels file is found.
+ filename: The filename where the class names are written.
+
+ Returns:
+ `True` if the labels file exists and `False` otherwise.
+ """
+ return tf.gfile.Exists(os.path.join(dataset_dir, filename))
+
+
+def read_label_file(dataset_dir, filename=LABELS_FILENAME):
+ """Reads the labels file and returns a mapping from ID to class name.
+
+ Args:
+ dataset_dir: The directory in which the labels file is found.
+ filename: The filename where the class names are written.
+
+ Returns:
+ A map from a label (integer) to class name.
+ """
+ labels_filename = os.path.join(dataset_dir, filename)
+ with tf.gfile.Open(labels_filename, 'rb') as f:
+ lines = f.read().decode()
+ lines = lines.split('\n')
+ lines = filter(None, lines)
+
+ labels_to_class_names = {}
+ for line in lines:
+ index = line.index(':')
+ labels_to_class_names[int(line[:index])] = line[index+1:]
+ return labels_to_class_names