Diffstat (limited to 'densenet.py')
-rwxr-xr-x  densenet.py  256
1 file changed, 256 insertions, 0 deletions
diff --git a/densenet.py b/densenet.py
new file mode 100755
index 0000000..22cf3cb
--- /dev/null
+++ b/densenet.py
@@ -0,0 +1,256 @@
+#!/usr/bin/python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""DenseNet implemented with Keras running on Cloud TPUs.
+
+This file shows how you can run DenseNet on a Cloud TPU using the TensorFlow
+Keras support. It is configured for ImageNet (i.e. 1000 classes), but you can
+easily adapt it to your own dataset by changing the code appropriately.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+import eval_utils
+import imagenet_input
+from models.densenet import DenseNetImageNet121
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.optimizer_v2 import gradient_descent
+
+try:
+  import h5py as _  # pylint: disable=g-import-not-at-top
+  HAS_H5PY = True
+except ImportError:
+  logging.warning('`h5py` is not installed. Please consider installing it '
+                  'to save weights for long-running training.')
+  HAS_H5PY = False
+
+
+# ImageNet training and test datasets.
+
+IMAGE_WIDTH = 320
+IMAGE_HEIGHT = 240
+EPOCHS = 90  # Standard ImageNet training regime.
+
+# Training hyperparameters.
+NUM_CORES = 8
+PER_CORE_BATCH_SIZE = 64
+CPU_BATCH_SIZE = 4
+BATCH_SIZE = NUM_CORES * PER_CORE_BATCH_SIZE
+BASE_LEARNING_RATE = 0.4
+# Computed in main() from the dataset size; used by learning_rate_schedule().
+train_steps_per_epoch = None
+# Learning rate schedule
+LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
+    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+]
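+# With BASE_LEARNING_RATE = 0.4 the schedule above works out to: a linear
+# warmup to 0.4 over epochs 0-5, then 0.4 for epochs 5-29, 0.04 for 30-59,
+# 0.004 for 60-79, and 0.0004 from epoch 80 on.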
+
+DEFAULT_WEIGHTS_H5 = 'densenet_weights.h5'
+DEFAULT_LOG_DIR = '/tmp/netcraft'
+DEFAULT_BUCKET = 'gs://netcraft/'
+
+flags.DEFINE_integer('epochs', EPOCHS, 'Number of epochs to train for.')
+flags.DEFINE_string('weights', None, 'Path to saved weights to load.')
+flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'GCS bucket to use.')
+flags.DEFINE_string('tpu', None, 'Name of the TPU to use.')
+flags.DEFINE_string('data', None, 'Path to training and testing data.')
+flags.DEFINE_string(
+    'log', DEFAULT_LOG_DIR,
+    'The directory where the model weights and training/evaluation summaries '
+    'are stored. If not specified, save to /tmp/netcraft.')
+flags.DEFINE_bool(
+    'complete_eval', True,
+    'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. '
+    'Furthermore, generate confusion matrices and save softmax values in '
+    'log_dir.')
+flags.DEFINE_bool('evalonly', False,
+                  'Only run eval with the given weights; do not train.')
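+# Example invocations (dataset paths and the TPU name are placeholders):
+#   python3 densenet.py --data=/path/to/dataset --log=/tmp/netcraft
+#   python3 densenet.py --tpu=my-tpu --bucket=gs://netcraft/ --data=dataset
+#   python3 densenet.py --data=/path/to/dataset \
+#       --weights=densenet_weights.h5 --evalonly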
+
+FLAGS = flags.FLAGS
+
+def learning_rate_schedule(current_epoch, current_batch):
+  """Handles linear scaling rule, gradual warmup, and LR decay.
+
+  The learning rate starts at 0, then it increases linearly per step.
+  After 5 epochs we reach the base learning rate (scaled to account
+  for batch size).
+  After 30, 60 and 80 epochs the learning rate is divided by 10.
+  After 90 epochs training stops and the LR is set to 0. This ensures
+  that we train for exactly 90 epochs for reproducibility.
+
+  Args:
+    current_epoch: integer, current epoch indexed from 0.
+    current_batch: integer, current batch in the current epoch, indexed from 0.
+
+  Returns:
+    Adjusted learning rate.
+  """
+  epoch = current_epoch + float(current_batch) / train_steps_per_epoch
+  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
+  if epoch < warmup_end_epoch:
+    # Learning rate increases linearly per step.
+    return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
+  for mult, start_epoch in LR_SCHEDULE:
+    if epoch >= start_epoch:
+      learning_rate = BASE_LEARNING_RATE * mult
+    else:
+      break
+  return learning_rate
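+# Worked example of the warmup branch above: at epoch 2, batch 0, the
+# schedule returns BASE_LEARNING_RATE * 1.0 * 2 / 5 = 0.16, climbing
+# linearly to 0.4 as training approaches epoch 5.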
+
+
+class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
+  """Callback to update learning rate on every batch (not epoch boundaries).
+
+  N.B. Only supports Keras optimizers, not TF optimizers.
+
+  Args:
+    schedule: a function that takes an epoch index and a batch index as input
+      (both integers, indexed from 0) and returns a new learning rate as
+      output (float).
+  """
+
+  def __init__(self, schedule):
+    super(LearningRateBatchScheduler, self).__init__()
+    self.schedule = schedule
+    self.epochs = -1
+    self.prev_lr = -1
+
+  def on_epoch_begin(self, epoch, logs=None):
+    if not hasattr(self.model.optimizer, 'lr'):
+      raise ValueError('Optimizer must have a "lr" attribute.')
+    self.epochs += 1
+
+  def on_batch_begin(self, batch, logs=None):
+    lr = self.schedule(self.epochs, batch)
+    if not isinstance(lr, (float, np.float32, np.float64)):
+      raise ValueError('The output of the "schedule" function should be a '
+                       'float.')
+    if lr != self.prev_lr:
+      K.set_value(self.model.optimizer.lr, lr)
+      self.prev_lr = lr
+      logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
+                    'changing learning rate to %s.', self.epochs, batch, lr)
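+
+# Minimal usage sketch for LearningRateBatchScheduler (hypothetical
+# standalone example; main() below wires it up for real training):
+#
+#   model.compile(optimizer=gradient_descent.SGD(learning_rate=0.4),
+#                 loss='sparse_categorical_crossentropy')
+#   model.fit(images, labels,
+#             callbacks=[LearningRateBatchScheduler(learning_rate_schedule)])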
+
+
+def main(argv):
+  global train_steps_per_epoch
+
+  if FLAGS.data:
+    dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz'))
+    classes = dinfo['classes']
+    NUM_CLASSES = len(classes)
+    train_cnt = dinfo['train_cnt']  # Approximate number of training images.
+    val_cnt = dinfo['val_cnt']  # Number of validation images.
+  else:
+    # No --data: use the synthetic-data path below with ImageNet's 1000
+    # classes (see the module docstring).
+    classes = np.arange(1000)
+    NUM_CLASSES = len(classes)
+    train_cnt = val_cnt = BATCH_SIZE
+  train_steps_per_epoch = train_cnt // BATCH_SIZE
+  val_steps = val_cnt // BATCH_SIZE
+
+  logging.info('Using %d training images and %d for validation.',
+               train_cnt, val_cnt)
+
+  logging.info('Building Keras DenseNet model.')
+  model = DenseNetImageNet121(classes=NUM_CLASSES, weights=None)
+
+  if FLAGS.tpu:
+    logging.info('Converting from CPU to TPU model.')
+    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
+    strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
+    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
+
+  logging.info('Compiling model.')
+  model.compile(
+      optimizer=gradient_descent.SGD(
+          learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
+      loss='sparse_categorical_crossentropy',
+      metrics=['sparse_categorical_accuracy'])
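+  # The sparse loss/metric pair above expects integer class indices as
+  # labels rather than one-hot vectors, matching the int32 labels used below.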
+
+  if FLAGS.data is None:
+    training_images = np.random.randn(
+        BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3).astype(np.float32)
+    training_labels = np.random.randint(NUM_CLASSES, size=BATCH_SIZE,
+                                        dtype=np.int32)
+    logging.info('Training model using synthetic data.')
+    model.fit(
+        training_images,
+        training_labels,
+        epochs=FLAGS.epochs,
+        batch_size=BATCH_SIZE)
+    logging.info('Evaluating the model on synthetic data.')
+    model.evaluate(training_images, training_labels, verbose=0)
+  else:
+    per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE
+    imagenet_train = imagenet_input.ImageNetInput(
+        is_training=True,
+        data_dir=FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data,
+        per_core_batch_size=per_core_batch_size)
+    logging.info('Training model using real data in directory "%s".',
+                 FLAGS.data)
+    # If running the complete evaluation, we feed the inputs from a Python
+    # generator, so we need to build a single batch for all of the cores,
+    # which will be split on the TPU.
+    per_core_batch_size = (
+        BATCH_SIZE if FLAGS.complete_eval else PER_CORE_BATCH_SIZE)
+    imagenet_validation = imagenet_input.ImageNetInput(
+        is_training=False,
+        data_dir=FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data,
+        per_core_batch_size=per_core_batch_size)
+
+    eval_callback = eval_utils.TensorBoardWithValidation(
+        log_dir=FLAGS.log,
+        validation_imagenet_input=imagenet_validation,
+        validation_steps=val_steps,
+        validation_epochs=[3, 10, 30, 60, 90],
+        write_images=True,
+        write_graph=True,
+        plot_wrong=True,
+        plot_cm=True,
+        plot_pr=True,
+        classes=classes,
+        complete_eval=FLAGS.complete_eval)
+
+    callbacks = [
+        LearningRateBatchScheduler(schedule=learning_rate_schedule),
+        eval_callback,
+    ]
+
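+    # The TPU-converted model takes the dataset-building callable itself so
+    # each worker can construct its own input pipeline; the plain Keras model
+    # consumes the tf.data.Dataset returned by calling input_fn().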
+    if FLAGS.tpu:
+      model_in = imagenet_train.input_fn
+    else:
+      model_in = imagenet_train.input_fn()
+
+    if FLAGS.weights:
+      weights_file = FLAGS.weights
+      logging.info('Loading model weights from %s', weights_file)
+      model.load_weights(weights_file)
+    else:
+      weights_file = DEFAULT_WEIGHTS_H5
+
+    if FLAGS.evalonly:
+      eval_callback.set_model(model)
+      eval_callback.on_epoch_end(420)
+    else:
+      model.fit(model_in,
+                epochs=FLAGS.epochs,
+                steps_per_epoch=train_steps_per_epoch,
+                callbacks=callbacks)
+
+      logging.info('Saving weights into %s', weights_file)
+      model.save_weights(weights_file, overwrite=True)
+
+
+if __name__ == '__main__':
+  tf.logging.set_verbosity(tf.logging.INFO)
+  app.run(main)