diff options
Diffstat (limited to 'resnet50.py')
-rwxr-xr-x | resnet50.py | 348 |
1 files changed, 348 insertions, 0 deletions
diff --git a/resnet50.py b/resnet50.py new file mode 100755 index 0000000..e063bfc --- /dev/null +++ b/resnet50.py @@ -0,0 +1,348 @@ +#!/usr/bin/python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""ResNet-50 implemented with Keras running on Cloud TPUs. + +This file shows how you can run ResNet-50 on a Cloud TPU using the TensorFlow +Keras support. This is configured for ImageNet (e.g. 1000 classes), but you can +easily adapt to your own datasets by changing the code appropriately. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import tensorflow as tf + +import eval_utils +import imagenet_input +import models +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.optimizer_v2 import gradient_descent, adam + +try: + import h5py as _ # pylint: disable=g-import-not-at-top + HAS_H5PY = True +except ImportError: + logging.warning('`h5py` is not installed. Please consider installing it ' + 'to save weights for long-running training.') + HAS_H5PY = False + + +# Imagenet training and test data sets. + +DEF_IMAGE_WIDTH = 320 +DEF_IMAGE_HEIGHT = 240 +DEF_EPOCHS = 90 # Standard imagenet training regime. + +# Training hyperparameters. +NUM_CORES = 8 +PER_CORE_BATCH_SIZE = 4 +CPU_BATCH_SIZE = 1 +BASE_LEARNING_RATE = 1e-3 +# Learning rate schedule +LR_SCHEDULE = [ # (multiplier, epoch to start) tuples + (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80) +] + +DEFAULT_WEIGHTS_H5 = 'resnet50_weights.h5' +DEFAULT_LOG_DIR = '/tmp/netcraft' +DEFAULT_BUCKET = 'gs://netcraft/' + +flags.DEFINE_float('lr', BASE_LEARNING_RATE, 'Defines the step size when training') +flags.DEFINE_integer('epochs', DEF_EPOCHS, 'Number of epochs until which to train') +flags.DEFINE_integer('split_epochs', 1, 'Split epochs into smaller bits, helps save weights') +flags.DEFINE_integer('initial_epoch', 0, 'Epoch from which to start, useful when resuming training') +flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, '') +flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, '') +flags.DEFINE_string('weights', None, 'Use saved weights') +flags.DEFINE_string('weights2', None, 'Use another saved weights') +flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'Bucket to use') +flags.DEFINE_string('tpu', None, 'Name of the TPU to use.') +flags.DEFINE_string('data', None, 'Path to training and testing data.') +flags.DEFINE_string('model', "resnet", 'Which logo to use (resnet, or logo)') +flags.DEFINE_string( + 'log', DEFAULT_LOG_DIR, + ('The directory where the model weights and training/evaluation summaries ' + 'are stored. If not specified, save to /tmp/netcraft.')) +flags.DEFINE_bool( + 'complete_eval', True, + 'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. ' + 'Furthemore generate confusion matrixes and save softmax values in log_dir') +flags.DEFINE_bool('evalonly', False, 'Only run eval with given weights, do not train') +flags.DEFINE_bool('class_weights', False, 'Use class weights to deal with imbalanced dataset') +flags.DEFINE_integer('benign_multiplier', 1, 'Multiplier for weigh tof benign class') +flags.DEFINE_bool('plot_wrong', False, 'Plot false images in tensorboard, make eval slower') +flags.DEFINE_bool('plot_cm', True, 'Plot confusion matrix in tensorboard') +flags.DEFINE_bool('plot_pr', True, 'Plot precision recall in tensorboard') +flags.DEFINE_bool('weights_by_name', False, 'Load weights by name, this allows loading weights with an incompatible fully '+ + 'connect layer i.e. a different number of targets. The FC layer is randomly initiated and needs to be trained.') + +FLAGS = flags.FLAGS + +def learning_rate_schedule(current_epoch, current_batch, train_steps_per_epoch, base_learning_rate): + """Handles linear scaling rule, gradual warmup, and LR decay. + + The learning rate starts at 0, then it increases linearly per step. + After 5 epochs we reach the base learning rate (scaled to account + for batch size). + After 30, 60 and 80 epochs the learning rate is divided by 10. + After 90 epochs training stops and the LR is set to 0. This ensures + that we train for exactly 90 epochs for reproducibility. + + Args: + current_epoch: integer, current epoch indexed from 0. + current_batch: integer, current batch in the current epoch, indexed from 0. + + Returns: + Adjusted learning rate. + """ + epoch = current_epoch + float(current_batch) / train_steps_per_epoch + warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0] + if epoch < warmup_end_epoch: + # Learning rate increases linearly per step. + return base_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch + for mult, start_epoch in LR_SCHEDULE: + if epoch >= start_epoch: + learning_rate = base_learning_rate * mult + else: + break + return learning_rate + + +class LearningRateBatchScheduler(tf.keras.callbacks.Callback): + """Callback to update learning rate on every batch (not epoch boundaries). + + N.B. Only support Keras optimizers, not TF optimizers. + + Args: + schedule: a function that takes an epoch index and a batch index as input + (both integer, indexed from 0) and returns a new learning rate as + output (float). + """ + + def __init__(self, schedule, train_steps_per_epoch, base_learning_rate): + super(LearningRateBatchScheduler, self).__init__() + self.base_lr = base_learning_rate + self.schedule = schedule + self.train_steps_per_epoch = train_steps_per_epoch + self.epochs = -1 + self.prev_lr = -1 + + def on_epoch_begin(self, epoch, logs=None): + if not hasattr(self.model.optimizer, 'lr'): + raise ValueError('Optimizer must have a "lr" attribute.') + self.epochs += 1 + + def on_batch_begin(self, batch, logs=None): + lr = self.schedule(self.epochs, batch, self.train_steps_per_epoch, self.base_lr) + if not isinstance(lr, (float, np.float32, np.float64)): + raise ValueError('The output of the "schedule" function should be float.') + if lr != self.prev_lr: + K.set_value(self.model.optimizer.lr, lr) + self.prev_lr = lr + logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change ' + 'learning rate to %s.', self.epochs, batch, lr) + + +def main(argv): + if FLAGS.data: + dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz'), allow_pickle=True) + classes = dinfo['classes'] + num_classes = len(classes) + train_cnt = dinfo['train_cnt'] # 1141 # 50273 # Approximate number of images. + val_cnt = dinfo['val_cnt'] # 488 # 12560 # Number of images. + class_weights = dinfo['class_weights'].tolist() + #class_weights = class_weights[()] # Unpack 0d np.array + + if FLAGS.class_weights and FLAGS.benign_multiplier != 1: + benign_class = np.squeeze(np.where(classes=='benign')) + if benign_class: + benign_class = np.asscalar(benign_class) + class_weights[benign_class] *= FLAGS.benign_multiplier + else: + logging.warning("Could not find benign class. Ignoring benign multiplier.") + else: + train_cnt = 10e5 + val_cnt = 10e4 + num_classes = 10e2 + + if FLAGS.tpu: + batch_size = NUM_CORES * PER_CORE_BATCH_SIZE + else: + batch_size = CPU_BATCH_SIZE + + train_steps_per_epoch = int(train_cnt / (batch_size * FLAGS.split_epochs)) + val_steps = int(val_cnt // batch_size ) + + logging.info("Using %d training images and %d for validation", train_cnt, val_cnt) + + if FLAGS.model == 'resnet': + logging.info('Building Keras ResNet-50 model') + model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes) + elif FLAGS.model == 'combined': + logging.info('Building Keras ResNet-50 + LOGO model') + model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes, resnet_trainable=False) + elif FLAGS.model == 'combined_trainable': + logging.info('Building Keras ResNet-50 + LOGO model') + model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes, resnet_trainable=True) + elif FLAGS.model == 'logo': + logging.info('Building LogoNet model') + model = models.get_logo_model(width=None, height=None, num_classes=num_classes, base_trainable=True) + elif FLAGS.model == 'logo_extended': + logging.info('Building LogoNet model') + model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=False, num_classes=num_classes) + elif FLAGS.model == 'logo_new': + logging.info('Building LogoNet model') + model = models.get_logo_model_new(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=False, num_classes=num_classes) + elif FLAGS.model == 'logo_extended_trainable': + logging.info('Building LogoNet model') + model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=True, num_classes=num_classes) + else: + return 'Only valid models are resnet and logo' + + if FLAGS.tpu: + logging.info('Converting from CPU to TPU model.') + resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu) + strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver) + model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy) + + logging.info('Compiling model.') + model.compile( + optimizer=adam.Adam(learning_rate=FLAGS.lr), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + + if FLAGS.data is None: + training_images = np.random.randn( + batch_size, FLAGS.image_height, FLAGS.image_width, 3).astype(np.float32) + training_labels = np.random.randint(num_classes, size=batch_size, + dtype=np.int32) + logging.info('Training model using synthetica data, use --data flag to provided real data.') + model.fit( + training_images, + training_labels, + epochs=FLAGS.epochs, + initial_epoch=FLAGS.initial_epoch, + batch_size=batch_size) + logging.info('Evaluating the model on synthetic data.') + model.evaluate(training_images, training_labels, verbose=0) + else: + per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE + imagenet_train = imagenet_input.ImageNetInput( + width=FLAGS.image_width, + height=FLAGS.image_height, + resize=False if (FLAGS.model == 'logo') else True, + is_training=True, + data_dir=FLAGS.bucket+FLAGS.data if FLAGS.tpu else FLAGS.data, + per_core_batch_size=per_core_batch_size) + logging.info('Training model using real data in directory "%s".', + FLAGS.data) + # If evaluating complete_eval, we feed the inputs from a Python generator, + # so we need to build a single batch for all of the cores, which will be + # split on TPU. + per_core_batch_size = ( + batch_size if (FLAGS.complete_eval or not FLAGS.tpu) else PER_CORE_BATCH_SIZE) + imagenet_validation = imagenet_input.ImageNetInput( + FLAGS.image_width, FLAGS.image_height, + resize=False if (FLAGS.model == 'logo') else True, + is_training=False, + data_dir=FLAGS.bucket+FLAGS.data if FLAGS.tpu else FLAGS.data, + per_core_batch_size=per_core_batch_size) + + if FLAGS.evalonly: + validation_epochs= [420] + logging.info("Only running a single validation epoch") + else: + validation_epochs=[ 3, 10, 30, 60, 90] + logging.info("Validation will be run on epochs %s", str(validation_epochs)) + + eval_callback = eval_utils.TensorBoardWithValidation( + log_dir=FLAGS.log, + validation_imagenet_input=imagenet_validation, + validation_steps=val_steps, + validation_epochs=validation_epochs, + write_images=True, + write_graph=True, + plot_wrong=FLAGS.plot_wrong, + plot_cm=FLAGS.plot_cm, + plot_pr=FLAGS.plot_pr, + classes=classes, + complete_eval=FLAGS.complete_eval) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(FLAGS.log+"/weights.{epoch:02d}-{sparse_categorical_accuracy:.2f}.hdf5", + monitor='sparse_categorical_accuracy', verbose=1, + save_best_only=True, save_weights_only=True, mode='auto'), + LearningRateBatchScheduler(schedule=learning_rate_schedule, train_steps_per_epoch=train_steps_per_epoch, base_learning_rate=FLAGS.lr), + eval_callback + ] + + if FLAGS.tpu: + model_in = imagenet_train.input_fn + else: + model_in = imagenet_train.input_fn() + + preloaded_weights = [] + for layer in model.layers: + preloaded_weights.append(layer.get_weights()) + + if FLAGS.weights: + weights_file = os.path.join(FLAGS.weights) + logging.info('Loading trained weights from %s', weights_file) + model.load_weights(weights_file, by_name=FLAGS.weights_by_name) + if FLAGS.weights2: + weights2_file = os.path.join(FLAGS.weights2) + logging.info('Loading secondary trained weights from %s', weights2_file) + model.load_weights(weights2_file, by_name=FLAGS.weights_by_name) + else: + if FLAGS.weights2: + logging.debug("Ignoring --weights2 flag as no --weights") + weights_file = os.path.join(DEFAULT_WEIGHTS_H5) + + # Check if we loaded weights + for layer, pre in zip(model.layers, preloaded_weights): + weights = layer.get_weights() + + populated=True + if weights: + for weight, pr in zip(weights, pre): + if np.array_equal(weight, pr): + populated=False + + if not populated: + logging.warning('Layer %s not populated with weights!', layer.name) + + if FLAGS.evalonly: + eval_callback.set_model(model) + eval_callback.on_epoch_end(420) + else: + model.fit(model_in, + epochs=FLAGS.epochs, + initial_epoch=FLAGS.initial_epoch, + class_weight = class_weights if FLAGS.class_weights else None, + steps_per_epoch=train_steps_per_epoch, + callbacks=callbacks) + +if __name__ == '__main__': + tf.logging.set_verbosity(tf.logging.INFO) + app.run(main) |