Diffstat (limited to 'resnet50.py')
| -rwxr-xr-x | resnet50.py | 348 | 
1 file changed, 348 insertions, 0 deletions
diff --git a/resnet50.py b/resnet50.py
new file mode 100755
index 0000000..e063bfc
--- /dev/null
+++ b/resnet50.py
@@ -0,0 +1,348 @@
+#!/usr/bin/python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""ResNet-50 implemented with Keras running on Cloud TPUs.
+
+This file shows how you can run ResNet-50 on a Cloud TPU using the TensorFlow
+Keras support. It is configured for ImageNet (i.e. 1000 classes), but you can
+easily adapt it to your own datasets by changing the code appropriately.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+import eval_utils
+import imagenet_input
+import models
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.optimizer_v2 import adam
+
+try:
+  import h5py as _  # pylint: disable=g-import-not-at-top
+  HAS_H5PY = True
+except ImportError:
+  logging.warning('`h5py` is not installed. Please consider installing it '
+                  'to save weights for long-running training.')
+  HAS_H5PY = False
+
+
+# Default input dimensions and training length.
+DEF_IMAGE_WIDTH = 320
+DEF_IMAGE_HEIGHT = 240
+DEF_EPOCHS = 90  # Standard ImageNet training regime.
+
+# Training hyperparameters.
+NUM_CORES = 8
+PER_CORE_BATCH_SIZE = 4
+CPU_BATCH_SIZE = 1
+BASE_LEARNING_RATE = 1e-3
+# Learning rate schedule: (multiplier, epoch to start) tuples.
+LR_SCHEDULE = [
+    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+]
+
+DEFAULT_WEIGHTS_H5 = 'resnet50_weights.h5'
+DEFAULT_LOG_DIR = '/tmp/netcraft'
+DEFAULT_BUCKET = 'gs://netcraft/'
+
+flags.DEFINE_float('lr', BASE_LEARNING_RATE, 'Step size used when training.')
+flags.DEFINE_integer('epochs', DEF_EPOCHS, 'Number of epochs to train for.')
+flags.DEFINE_integer('split_epochs', 1, 'Split epochs into smaller parts; helps save weights more often.')
+flags.DEFINE_integer('initial_epoch', 0, 'Epoch to start from, useful when resuming training.')
+flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, 'Input image width.')
+flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, 'Input image height.')
+flags.DEFINE_string('weights', None, 'Load saved weights from this file.')
+flags.DEFINE_string('weights2', None, 'Load a second set of saved weights from this file.')
+flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'GCS bucket to use.')
+flags.DEFINE_string('tpu', None, 'Name of the TPU to use.')
+flags.DEFINE_string('data', None, 'Path to training and testing data.')
+flags.DEFINE_string('model', 'resnet',
+                    'Which model to use (resnet, combined, combined_trainable, logo, '
+                    'logo_extended, logo_new or logo_extended_trainable).')
+flags.DEFINE_string(
+    'log', DEFAULT_LOG_DIR,
+    ('The directory where the model weights and training/evaluation summaries '
+     'are stored. If not specified, save to /tmp/netcraft.'))
+flags.DEFINE_bool(
+    'complete_eval', True,
+    'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. '
+    'Furthermore, generate confusion matrices and save softmax values in log_dir.')
+flags.DEFINE_bool('evalonly', False, 'Only run eval with the given weights, do not train.')
+flags.DEFINE_bool('class_weights', False, 'Use class weights to deal with an imbalanced dataset.')
+flags.DEFINE_integer('benign_multiplier', 1, 'Multiplier for the weight of the benign class.')
+flags.DEFINE_bool('plot_wrong', False, 'Plot misclassified images in TensorBoard; makes eval slower.')
+flags.DEFINE_bool('plot_cm', True, 'Plot confusion matrix in TensorBoard.')
+flags.DEFINE_bool('plot_pr', True, 'Plot precision/recall in TensorBoard.')
+flags.DEFINE_bool('weights_by_name', False,
+                  'Load weights by name. This allows loading weights with an incompatible fully '
+                  'connected layer, i.e. a different number of targets. The FC layer is randomly '
+                  'initialized and needs to be trained.')
+
+FLAGS = flags.FLAGS
+
+
+def learning_rate_schedule(current_epoch, current_batch, train_steps_per_epoch, base_learning_rate):
+  """Handles linear scaling rule, gradual warmup, and LR decay.
+
+  The learning rate starts at 0, then increases linearly per step. After 5
+  epochs it reaches the base learning rate (scaled to account for batch
+  size). After 30, 60 and 80 epochs the learning rate is divided by 10.
+  After 90 epochs training stops and the LR is set to 0, so that we train
+  for exactly 90 epochs for reproducibility.
+
+  Args:
+    current_epoch: integer, current epoch indexed from 0.
+    current_batch: integer, current batch in the current epoch, indexed from 0.
+    train_steps_per_epoch: integer, number of batches per training epoch.
+    base_learning_rate: float, base learning rate before scaling.
+
+  Returns:
+    Adjusted learning rate.
+  """
+  epoch = current_epoch + float(current_batch) / train_steps_per_epoch
+  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
+  if epoch < warmup_end_epoch:
+    # Learning rate increases linearly per step.
+    return base_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
+  for mult, start_epoch in LR_SCHEDULE:
+    if epoch >= start_epoch:
+      learning_rate = base_learning_rate * mult
+    else:
+      break
+  return learning_rate
+
+
+class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
+  """Callback to update learning rate on every batch (not epoch boundaries).
+
+  N.B. Only supports Keras optimizers, not TF optimizers.
+
+  Args:
+    schedule: a function that takes an epoch index and a batch index as input
+      (both integer, indexed from 0) and returns a new learning rate as
+      output (float).
+  """
+
+  def __init__(self, schedule, train_steps_per_epoch, base_learning_rate):
+    super(LearningRateBatchScheduler, self).__init__()
+    self.base_lr = base_learning_rate
+    self.schedule = schedule
+    self.train_steps_per_epoch = train_steps_per_epoch
+    self.epochs = -1
+    self.prev_lr = -1
+
+  def on_epoch_begin(self, epoch, logs=None):
+    if not hasattr(self.model.optimizer, 'lr'):
+      raise ValueError('Optimizer must have a "lr" attribute.')
+    self.epochs += 1
+
+  def on_batch_begin(self, batch, logs=None):
+    lr = self.schedule(self.epochs, batch, self.train_steps_per_epoch, self.base_lr)
+    if not isinstance(lr, (float, np.float32, np.float64)):
+      raise ValueError('The output of the "schedule" function should be float.')
+    if lr != self.prev_lr:
+      K.set_value(self.model.optimizer.lr, lr)
+      self.prev_lr = lr
+      logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
+                    'learning rate to %s.', self.epochs, batch, lr)
+
+
+def main(argv):
+  if FLAGS.data:
+    dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz'), allow_pickle=True)
+    classes = dinfo['classes']
+    num_classes = len(classes)
+    train_cnt = dinfo['train_cnt']  # Approximate number of training images.
+    val_cnt = dinfo['val_cnt']  # Number of validation images.
+    class_weights = dinfo['class_weights'].tolist()
+
+    if FLAGS.class_weights and FLAGS.benign_multiplier != 1:
+      benign_class = np.squeeze(np.where(classes == 'benign'))
+      if benign_class.size:
+        class_weights[benign_class.item()] *= FLAGS.benign_multiplier
+      else:
+        logging.warning('Could not find benign class. Ignoring benign multiplier.')
+  else:
+    train_cnt = 1000000
+    val_cnt = 100000
+    num_classes = 1000
+
+  if FLAGS.tpu:
+    batch_size = NUM_CORES * PER_CORE_BATCH_SIZE
+  else:
+    batch_size = CPU_BATCH_SIZE
+
+  train_steps_per_epoch = int(train_cnt / (batch_size * FLAGS.split_epochs))
+  val_steps = int(val_cnt // batch_size)
+
+  logging.info('Using %d training images and %d for validation', train_cnt, val_cnt)
+
+  if FLAGS.model == 'resnet':
+    logging.info('Building Keras ResNet-50 model')
+    model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height,
+                            num_classes=num_classes)
+  elif FLAGS.model == 'combined':
+    logging.info('Building Keras ResNet-50 + LOGO model')
+    model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height,
+                                     num_classes=num_classes, resnet_trainable=False)
+  elif FLAGS.model == 'combined_trainable':
+    logging.info('Building Keras ResNet-50 + LOGO model')
+    model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height,
+                                     num_classes=num_classes, resnet_trainable=True)
+  elif FLAGS.model == 'logo':
+    logging.info('Building LogoNet model')
+    model = models.get_logo_model(width=None, height=None, num_classes=num_classes,
+                                  base_trainable=True)
+  elif FLAGS.model == 'logo_extended':
+    logging.info('Building LogoNet model')
+    model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height,
+                                  base_trainable=False, num_classes=num_classes)
+  elif FLAGS.model == 'logo_new':
+    logging.info('Building LogoNet model')
+    model = models.get_logo_model_new(width=FLAGS.image_width, height=FLAGS.image_height,
+                                      base_trainable=False, num_classes=num_classes)
+  elif FLAGS.model == 'logo_extended_trainable':
+    logging.info('Building LogoNet model')
+    model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height,
+                                  base_trainable=True, num_classes=num_classes)
+  else:
+    return ('Valid models are resnet, combined, combined_trainable, logo, '
+            'logo_extended, logo_new and logo_extended_trainable')
+
+  if FLAGS.tpu:
+    logging.info('Converting from CPU to TPU model.')
+    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
+    strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
+    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
+
+  logging.info('Compiling model.')
+  model.compile(
+      optimizer=adam.Adam(learning_rate=FLAGS.lr),
+      loss='sparse_categorical_crossentropy',
+      metrics=['sparse_categorical_accuracy'])
+
+  if FLAGS.data is None:
+    training_images = np.random.randn(
+        batch_size, FLAGS.image_height, FLAGS.image_width, 3).astype(np.float32)
+    training_labels = np.random.randint(num_classes, size=batch_size,
+                                        dtype=np.int32)
+    logging.info('Training model using synthetic data; use the --data flag to provide real data.')
+    model.fit(
+        training_images,
+        training_labels,
+        epochs=FLAGS.epochs,
+        initial_epoch=FLAGS.initial_epoch,
+        batch_size=batch_size)
+    logging.info('Evaluating the model on synthetic data.')
+    model.evaluate(training_images, training_labels, verbose=0)
+  else:
+    per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE
+    imagenet_train = imagenet_input.ImageNetInput(
+        width=FLAGS.image_width,
+        height=FLAGS.image_height,
+        resize=(FLAGS.model != 'logo'),
+        is_training=True,
+        data_dir=FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data,
+        per_core_batch_size=per_core_batch_size)
+    logging.info('Training model using real data in directory "%s".',
+                 FLAGS.data)
+    # If running complete_eval, we feed the inputs from a Python generator,
+    # so we need to build a single batch for all of the cores, which will be
+    # split on TPU.
+    per_core_batch_size = (
+        batch_size if (FLAGS.complete_eval or not FLAGS.tpu) else PER_CORE_BATCH_SIZE)
+    imagenet_validation = imagenet_input.ImageNetInput(
+        FLAGS.image_width, FLAGS.image_height,
+        resize=(FLAGS.model != 'logo'),
+        is_training=False,
+        data_dir=FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data,
+        per_core_batch_size=per_core_batch_size)
+
+    if FLAGS.evalonly:
+      validation_epochs = [420]
+      logging.info('Only running a single validation epoch')
+    else:
+      validation_epochs = [3, 10, 30, 60, 90]
+      logging.info('Validation will be run on epochs %s', str(validation_epochs))
+
+    eval_callback = eval_utils.TensorBoardWithValidation(
+        log_dir=FLAGS.log,
+        validation_imagenet_input=imagenet_validation,
+        validation_steps=val_steps,
+        validation_epochs=validation_epochs,
+        write_images=True,
+        write_graph=True,
+        plot_wrong=FLAGS.plot_wrong,
+        plot_cm=FLAGS.plot_cm,
+        plot_pr=FLAGS.plot_pr,
+        classes=classes,
+        complete_eval=FLAGS.complete_eval)
+
+    callbacks = [
+        tf.keras.callbacks.ModelCheckpoint(
+            FLAGS.log + '/weights.{epoch:02d}-{sparse_categorical_accuracy:.2f}.hdf5',
+            monitor='sparse_categorical_accuracy', verbose=1,
+            save_best_only=True, save_weights_only=True, mode='auto'),
+        LearningRateBatchScheduler(schedule=learning_rate_schedule,
+                                   train_steps_per_epoch=train_steps_per_epoch,
+                                   base_learning_rate=FLAGS.lr),
+        eval_callback,
+    ]
+
+    if FLAGS.tpu:
+      model_in = imagenet_train.input_fn
+    else:
+      model_in = imagenet_train.input_fn()
+
+    # Snapshot the initial weights so we can tell afterwards whether
+    # load_weights() actually populated each layer.
+    preloaded_weights = []
+    for layer in model.layers:
+      preloaded_weights.append(layer.get_weights())
+
+    if FLAGS.weights:
+      weights_file = os.path.join(FLAGS.weights)
+      logging.info('Loading trained weights from %s', weights_file)
+      model.load_weights(weights_file, by_name=FLAGS.weights_by_name)
+      if FLAGS.weights2:
+        weights2_file = os.path.join(FLAGS.weights2)
+        logging.info('Loading secondary trained weights from %s', weights2_file)
+        model.load_weights(weights2_file, by_name=FLAGS.weights_by_name)
+    else:
+      if FLAGS.weights2:
+        logging.debug('Ignoring --weights2 flag because --weights was not given')
+      weights_file = os.path.join(DEFAULT_WEIGHTS_H5)
+
+    # Warn about layers whose weights are unchanged from the snapshot above;
+    # they were most likely not populated by load_weights().
+    for layer, pre in zip(model.layers, preloaded_weights):
+      weights = layer.get_weights()
+
+      populated = True
+      if weights:
+        for weight, pr in zip(weights, pre):
+          if np.array_equal(weight, pr):
+            populated = False
+
+      if not populated:
+        logging.warning('Layer %s not populated with weights!', layer.name)
+
+    if FLAGS.evalonly:
+      eval_callback.set_model(model)
+      eval_callback.on_epoch_end(420)
+    else:
+      model.fit(model_in,
+                epochs=FLAGS.epochs,
+                initial_epoch=FLAGS.initial_epoch,
+                class_weight=class_weights if FLAGS.class_weights else None,
+                steps_per_epoch=train_steps_per_epoch,
+                callbacks=callbacks)
+
+
+if __name__ == '__main__':
+  tf.logging.set_verbosity(tf.logging.INFO)
+  app.run(main)
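As a quick sanity check (not part of the commit), the warmup and step-decay behaviour of learning_rate_schedule can be reproduced standalone. The sketch below mirrors the function's logic and assumes the default BASE_LEARNING_RATE of 1e-3:

    # Minimal standalone sketch of the schedule in resnet50.py.
    LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]

    def lr_at(epoch, base_lr=1e-3):
      warmup_mult, warmup_end = LR_SCHEDULE[0]
      if epoch < warmup_end:
        # Linear warmup from 0 to base_lr over the first 5 epochs.
        return base_lr * warmup_mult * epoch / warmup_end
      lr = base_lr * warmup_mult
      for mult, start in LR_SCHEDULE:
        if epoch >= start:
          lr = base_lr * mult
        else:
          break
      return lr

    for e in [0, 2.5, 5, 29, 30, 60, 80]:
      print(e, lr_at(e))  # -> 0.0, 5e-4, 1e-3, 1e-3, 1e-4, 1e-5, 1e-6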

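The benign-class reweighting in main() can be exercised the same way. The class list below is made up; note that the .size test (rather than a plain truthiness test on the squeezed index) keeps the logic correct even when 'benign' sits at index 0:

    # Hypothetical example of the --class_weights / --benign_multiplier logic.
    import numpy as np

    classes = np.array(['benign', 'paypal', 'apple'])  # made-up class list
    class_weights = [1.0, 5.0, 5.0]
    benign_multiplier = 3  # as if run with --benign_multiplier=3

    benign_class = np.squeeze(np.where(classes == 'benign'))
    if benign_class.size:
      class_weights[benign_class.item()] *= benign_multiplier
    print(class_weights)  # [3.0, 5.0, 5.0]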