diff options
Diffstat (limited to 'densenet.py')
-rwxr-xr-x | densenet.py | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/densenet.py b/densenet.py new file mode 100755 index 0000000..22cf3cb --- /dev/null +++ b/densenet.py @@ -0,0 +1,256 @@ +#!/usr/bin/python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""DenseNet implemented with Keras running on Cloud TPUs. + +This file shows how you can run DenseNet on a Cloud TPU using the TensorFlow +Keras support. This is configured for ImageNet (e.g. 1000 classes), but you can +easily adapt to your own datasets by changing the code appropriately. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import tensorflow as tf + +import eval_utils +import imagenet_input +from models.densenet import DenseNetImageNet121 +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.optimizer_v2 import gradient_descent, adam + +try: + import h5py as _ # pylint: disable=g-import-not-at-top + HAS_H5PY = True +except ImportError: + logging.warning('`h5py` is not installed. Please consider installing it ' + 'to save weights for long-running training.') + HAS_H5PY = False + + +# Imagenet training and test data sets. + +IMAGE_WIDTH = 320 +IMAGE_HEIGHT = 240 +EPOCHS = 90 # Standard imagenet training regime. + +# Training hyperparameters. +NUM_CORES = 8 +PER_CORE_BATCH_SIZE = 64 +CPU_BATCH_SIZE = 4 +BATCH_SIZE = NUM_CORES * PER_CORE_BATCH_SIZE +BASE_LEARNING_RATE = 0.4 +# Learning rate schedule +LR_SCHEDULE = [ # (multiplier, epoch to start) tuples + (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80) +] + +DEFAULT_WEIGHTS_H5 = 'resnet50_weights.h5' +DEFAULT_LOG_DIR = '/tmp/netcraft' +DEFAULT_BUCKET = 'gs://netcraft/' + +flags.DEFINE_integer('epochs', EPOCHS, '') +flags.DEFINE_string('weights', None, 'Use saved weights') +flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'Bucket to use') +flags.DEFINE_string('tpu', None, 'Name of the TPU to use.') +flags.DEFINE_string('data', None, 'Path to training and testing data.') +flags.DEFINE_string( + 'log', DEFAULT_LOG_DIR, + ('The directory where the model weights and training/evaluation summaries ' + 'are stored. If not specified, save to /tmp/netcraft.')) +flags.DEFINE_bool( + 'complete_eval', True, + 'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. ' + 'Furthemore generate confusion matrixes and save softmax values in log_dir') +flags.DEFINE_bool('evalonly', False, 'Only run eval with given weights, do not train') + +FLAGS = flags.FLAGS + +def learning_rate_schedule(current_epoch, current_batch): + """Handles linear scaling rule, gradual warmup, and LR decay. + + The learning rate starts at 0, then it increases linearly per step. + After 5 epochs we reach the base learning rate (scaled to account + for batch size). + After 30, 60 and 80 epochs the learning rate is divided by 10. + After 90 epochs training stops and the LR is set to 0. This ensures + that we train for exactly 90 epochs for reproducibility. + + Args: + current_epoch: integer, current epoch indexed from 0. + current_batch: integer, current batch in the current epoch, indexed from 0. + + Returns: + Adjusted learning rate. + """ + return 0.0 + epoch = current_epoch + float(current_batch) / train_steps_per_epoch + warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0] + if epoch < warmup_end_epoch: + # Learning rate increases linearly per step. + return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch + for mult, start_epoch in LR_SCHEDULE: + if epoch >= start_epoch: + learning_rate = BASE_LEARNING_RATE * mult + else: + break + return learning_rate + + +class LearningRateBatchScheduler(tf.keras.callbacks.Callback): + """Callback to update learning rate on every batch (not epoch boundaries). + + N.B. Only support Keras optimizers, not TF optimizers. + + Args: + schedule: a function that takes an epoch index and a batch index as input + (both integer, indexed from 0) and returns a new learning rate as + output (float). + """ + + def __init__(self, schedule): + super(LearningRateBatchScheduler, self).__init__() + self.schedule = schedule + self.epochs = -1 + self.prev_lr = -1 + + def on_epoch_begin(self, epoch, logs=None): + if not hasattr(self.model.optimizer, 'lr'): + raise ValueError('Optimizer must have a "lr" attribute.') + self.epochs += 1 + + def on_batch_begin(self, batch, logs=None): + lr = self.schedule(self.epochs, batch) + if not isinstance(lr, (float, np.float32, np.float64)): + raise ValueError('The output of the "schedule" function should be float.') + if lr != self.prev_lr: + K.set_value(self.model.optimizer.lr, lr) + self.prev_lr = lr + logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change ' + 'learning rate to %s.', self.epochs, batch, lr) + + +def main(argv): + dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz')) + classes = dinfo['classes'] + NUM_CLASSES = len(classes) + train_cnt = dinfo['train_cnt'] # 1141 # 50273 # Approximate number of images. + val_cnt = dinfo['val_cnt'] # 488 # 12560 # Number of images. + train_steps_per_epoch = int(train_cnt / BATCH_SIZE) + val_steps = int(val_cnt // BATCH_SIZE ) + + print("Using", train_cnt, "training images and", val_cnt, "for testing") + + logging.info('Building Keras DenseNet model') + model = DenseNetImageNet121(classes=NUM_CLASSES, weights=None) + + if FLAGS.tpu: + logging.info('Converting from CPU to TPU model.') + resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu) + strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver) + model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy) + + logging.info('Compiling model.') + model.compile( + optimizer=gradient_descent.SGD(learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + + if FLAGS.data is None: + training_images = np.random.randn( + BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3).astype(np.float32) + training_labels = np.random.randint(NUM_CLASSES, size=BATCH_SIZE, + dtype=np.int32) + logging.info('Training model using synthetica data.') + model.fit( + training_images, + training_labels, + epochs=EPOCHS, + batch_size=BATCH_SIZE) + logging.info('Evaluating the model on synthetic data.') + model.evaluate(training_images, training_labels, verbose=0) + else: + per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE + imagenet_train = imagenet_input.ImageNetInput( + is_training=True, + data_dir=FLAGS.bucket+FLAGS.data if FLAGS.tpu else FLAGS.data, + per_core_batch_size=per_core_batch_size) + logging.info('Training model using real data in directory "%s".', + FLAGS.data) + # If evaluating complete_eval, we feed the inputs from a Python generator, + # so we need to build a single batch for all of the cores, which will be + # split on TPU. + per_core_batch_size = ( + BATCH_SIZE if FLAGS.complete_eval else PER_CORE_BATCH_SIZE) + imagenet_validation = imagenet_input.ImageNetInput( + is_training=False, + data_dir=FLAGS.bucket+FLAGS.data if FLAGS.tpu else FLAGS.data, + per_core_batch_size=per_core_batch_size) + + eval_callback = eval_utils.TensorBoardWithValidation( + log_dir=FLAGS.log, + validation_imagenet_input=imagenet_validation, + validation_steps=val_steps, + validation_epochs=[ 3, 10, 30, 60, 90], + write_images=True, + write_graph=True, + plot_wrong=True, + plot_cm=True, + plot_pr=True, + classes=classes, + complete_eval=FLAGS.complete_eval) + + callbacks = [ + LearningRateBatchScheduler(schedule=learning_rate_schedule), + eval_callback + ] + + if FLAGS.tpu: + model_in = imagenet_train.input_fn + else: + model_in = imagenet_train.input_fn() + + + if FLAGS.weights: + weights_file = os.path.join(FLAGS.weights) + logging.info('Loading model and weights from %s', weights_file) + model.load_weights(weights_file) + else: + weights_file = os.path.join(DEFAULT_WEIGHTS_H5) + + if FLAGS.evalonly: + eval_callback.set_model(model) + eval_callback.on_epoch_end(420) + else: + model.fit(model_in, + epochs=EPOCHS, + steps_per_epoch=train_steps_per_epoch, + callbacks=callbacks) + + logging.info('Saving weights into %s', weights_file) + model.save_weights(weights_file, overwrite=True) + + +if __name__ == '__main__': + tf.logging.set_verbosity(tf.logging.INFO) + app.run(main) |