#!/usr/bin/python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet-50 implemented with Keras running on Cloud TPUs.

This file shows how you can run ResNet-50 on a Cloud TPU using the TensorFlow
Keras support. It is configured for ImageNet (i.e. 1000 classes), but you can
easily adapt it to your own datasets by changing the code appropriately.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf

import eval_utils
import imagenet_input
import models
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import gradient_descent, adam

try:
  import h5py as _  # pylint: disable=g-import-not-at-top
  HAS_H5PY = True
except ImportError:
  logging.warning('`h5py` is not installed. Please consider installing it '
                  'to save weights for long-running training.')
  HAS_H5PY = False

# Default input image dimensions.
DEF_IMAGE_WIDTH = 320
DEF_IMAGE_HEIGHT = 240
DEF_EPOCHS = 90  # Standard ImageNet training regime.

# Training hyperparameters.
NUM_CORES = 8
PER_CORE_BATCH_SIZE = 4
CPU_BATCH_SIZE = 1
BASE_LEARNING_RATE = 1e-3

# Learning rate schedule.
LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]

DEFAULT_WEIGHTS_H5 = 'resnet50_weights.h5'
DEFAULT_LOG_DIR = '/tmp/netcraft'
DEFAULT_BUCKET = 'gs://netcraft/'

flags.DEFINE_float('lr', BASE_LEARNING_RATE,
                   'Base learning rate (step size) used for training.')
flags.DEFINE_integer('epochs', DEF_EPOCHS,
                     'Number of epochs to train for.')
flags.DEFINE_integer('split_epochs', 1,
                     'Split each epoch into this many smaller epochs; helps '
                     'save weights more frequently.')
flags.DEFINE_integer('initial_epoch', 0,
                     'Epoch from which to start; useful when resuming training.')
flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, 'Input image width in pixels.')
flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, 'Input image height in pixels.')
flags.DEFINE_string('weights', None, 'Path to saved weights to load.')
flags.DEFINE_string('weights2', None, 'Path to a second set of saved weights to load.')
flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'GCS bucket to use.')
flags.DEFINE_string('tpu', None, 'Name of the TPU to use.')
flags.DEFINE_string('data', None, 'Path to training and testing data.')
flags.DEFINE_string('model', 'resnet',
                    'Which model to build (resnet, combined, combined_trainable, '
                    'logo, logo_extended, logo_new or logo_extended_trainable).')
flags.DEFINE_string(
    'log', DEFAULT_LOG_DIR,
    ('The directory where the model weights and training/evaluation summaries '
     'are stored. If not specified, save to /tmp/netcraft.'))
flags.DEFINE_bool(
    'complete_eval', True,
    'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. '
    'Furthermore, generate confusion matrices and save softmax values in log_dir.')
flags.DEFINE_bool('evalonly', False,
                  'Only run eval with the given weights, do not train.')
flags.DEFINE_bool('class_weights', False,
                  'Use class weights to deal with an imbalanced dataset.')
flags.DEFINE_integer('benign_multiplier', 1,
                     'Multiplier for the weight of the benign class.')
flags.DEFINE_bool('plot_wrong', False,
                  'Plot misclassified images in TensorBoard; makes eval slower.')
flags.DEFINE_bool('plot_cm', True, 'Plot the confusion matrix in TensorBoard.')
flags.DEFINE_bool('plot_pr', True, 'Plot precision/recall in TensorBoard.')
flags.DEFINE_bool('weights_by_name', False,
                  'Load weights by name. This allows loading weights with an '
                  'incompatible fully connected layer, i.e. a different number '
                  'of targets. The FC layer is randomly initialized and needs '
                  'to be trained.')

FLAGS = flags.FLAGS
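
# Example invocations (illustrative only: the script name, dataset paths and
# TPU/bucket names below are placeholders, not values shipped with this code):
#
#   python3 resnet50_keras.py --tpu=my-tpu --data=my-dataset --model=resnet
#   python3 resnet50_keras.py --data=/path/to/dataset --weights=weights.hdf5 --evalonly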


def learning_rate_schedule(current_epoch, current_batch, train_steps_per_epoch,
                           base_learning_rate):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then increases linearly per step. After 5
  epochs we reach the base learning rate (scaled to account for batch size).
  After 30, 60 and 80 epochs the learning rate is divided by 10. Training is
  intended to run for exactly 90 epochs (the default --epochs) for
  reproducibility.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    train_steps_per_epoch: integer, number of batches in one training epoch.
    base_learning_rate: float, base learning rate before scheduling.

  Returns:
    Adjusted learning rate.
  """
  epoch = current_epoch + float(current_batch) / train_steps_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return base_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = base_learning_rate * mult
    else:
      break
  return learning_rate


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only supports Keras optimizers, not TF optimizers.

  Args:
    schedule: a function that takes an epoch index and a batch index as input
      (both integer, indexed from 0) and returns a new learning rate as output
      (float).
    train_steps_per_epoch: integer, number of batches per training epoch.
    base_learning_rate: float, base learning rate passed to `schedule`.
  """

  def __init__(self, schedule, train_steps_per_epoch, base_learning_rate):
    super(LearningRateBatchScheduler, self).__init__()
    self.base_lr = base_learning_rate
    self.schedule = schedule
    self.train_steps_per_epoch = train_steps_per_epoch
    self.epochs = -1
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    lr = self.schedule(self.epochs, batch, self.train_steps_per_epoch,
                       self.base_lr)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function should be float.')
    if lr != self.prev_lr:
      K.set_value(self.model.optimizer.lr, lr)
      self.prev_lr = lr
      logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
                    'learning rate to %s.', self.epochs, batch, lr)
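
# Worked example of learning_rate_schedule above (a sketch: the exact values
# depend on --lr and on train_steps_per_epoch, assumed here to be 1e-3 and 100
# respectively):
#   epoch 0,  batch 0  -> 0.0      (start of the linear warmup)
#   epoch 2,  batch 50 -> 5e-4     (halfway through the 5-epoch warmup)
#   epochs 5-29        -> 1e-3     (full base learning rate)
#   epochs 30-59       -> 1e-4
#   epochs 60-79       -> 1e-5
#   epochs 80+         -> 1e-6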


def main(argv):
  if FLAGS.data:
    dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz'), allow_pickle=True)
    classes = dinfo['classes']
    num_classes = len(classes)
    train_cnt = dinfo['train_cnt']  # 1141 # 50273 # Approximate number of images.
    val_cnt = dinfo['val_cnt']  # 488 # 12560 # Number of images.
    class_weights = dinfo['class_weights'].tolist()
    # class_weights = class_weights[()]  # Unpack 0d np.array
    if FLAGS.class_weights and FLAGS.benign_multiplier != 1:
      benign_class = np.flatnonzero(classes == 'benign')
      if benign_class.size:
        class_weights[int(benign_class[0])] *= FLAGS.benign_multiplier
      else:
        logging.warning('Could not find benign class. Ignoring benign multiplier.')
  else:
    train_cnt = 1000000
    val_cnt = 100000
    num_classes = 1000

  if FLAGS.tpu:
    batch_size = NUM_CORES * PER_CORE_BATCH_SIZE
  else:
    batch_size = CPU_BATCH_SIZE

  train_steps_per_epoch = int(train_cnt / (batch_size * FLAGS.split_epochs))
  val_steps = int(val_cnt // batch_size)

  logging.info("Using %d training images and %d for validation", train_cnt,
               val_cnt)

  if FLAGS.model == 'resnet':
    logging.info('Building Keras ResNet-50 model')
    model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height,
                            num_classes=num_classes)
  elif FLAGS.model == 'combined':
    logging.info('Building Keras ResNet-50 + LOGO model')
    model = models.get_logores_model(width=FLAGS.image_width,
                                     height=FLAGS.image_height,
                                     num_classes=num_classes,
                                     resnet_trainable=False)
  elif FLAGS.model == 'combined_trainable':
    logging.info('Building Keras ResNet-50 + LOGO model')
    model = models.get_logores_model(width=FLAGS.image_width,
                                     height=FLAGS.image_height,
                                     num_classes=num_classes,
                                     resnet_trainable=True)
  elif FLAGS.model == 'logo':
    logging.info('Building LogoNet model')
    model = models.get_logo_model(width=None, height=None,
                                  num_classes=num_classes, base_trainable=True)
  elif FLAGS.model == 'logo_extended':
    logging.info('Building LogoNet model')
    model = models.get_logo_model(width=FLAGS.image_width,
                                  height=FLAGS.image_height,
                                  base_trainable=False,
                                  num_classes=num_classes)
  elif FLAGS.model == 'logo_new':
    logging.info('Building LogoNet model')
    model = models.get_logo_model_new(width=FLAGS.image_width,
                                      height=FLAGS.image_height,
                                      base_trainable=False,
                                      num_classes=num_classes)
  elif FLAGS.model == 'logo_extended_trainable':
    logging.info('Building LogoNet model')
    model = models.get_logo_model(width=FLAGS.image_width,
                                  height=FLAGS.image_height,
                                  base_trainable=True,
                                  num_classes=num_classes)
  else:
    return ('Invalid --model; valid values are resnet, combined, '
            'combined_trainable, logo, logo_extended, logo_new and '
            'logo_extended_trainable')

  if FLAGS.tpu:
    logging.info('Converting from CPU to TPU model.')
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

  logging.info('Compiling model.')
  model.compile(
      optimizer=adam.Adam(learning_rate=FLAGS.lr),
      loss='sparse_categorical_crossentropy',
      metrics=['sparse_categorical_accuracy'])

  if FLAGS.data is None:
    training_images = np.random.randn(
        batch_size, FLAGS.image_height, FLAGS.image_width, 3).astype(np.float32)
    training_labels = np.random.randint(num_classes, size=batch_size,
                                        dtype=np.int32)
    logging.info('Training model using synthetic data; use the --data flag to '
                 'provide real data.')
    model.fit(
        training_images,
        training_labels,
        epochs=FLAGS.epochs,
        initial_epoch=FLAGS.initial_epoch,
        batch_size=batch_size)
    logging.info('Evaluating the model on synthetic data.')
    model.evaluate(training_images, training_labels, verbose=0)
  else:
    per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE
    imagenet_train = imagenet_input.ImageNetInput(
        width=FLAGS.image_width,
        height=FLAGS.image_height,
        resize=False if (FLAGS.model == 'logo') else True,
        is_training=True,
        data_dir=FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data,
        per_core_batch_size=per_core_batch_size)
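    # Note on how this feeds the model (descriptive only): imagenet_train
    # exposes an input_fn; on TPU the callable itself is passed to model.fit
    # (see model_in below), while on CPU it is called once to build the
    # dataset that model.fit consumes directly.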
    logging.info('Training model using real data in directory "%s".',
                 FLAGS.data)
    # If evaluating complete_eval, we feed the inputs from a Python generator,
    # so we need to build a single batch for all of the cores, which will be
    # split on TPU.
    per_core_batch_size = (
        batch_size if (FLAGS.complete_eval or not FLAGS.tpu)
        else PER_CORE_BATCH_SIZE)
    imagenet_validation = imagenet_input.ImageNetInput(
        FLAGS.image_width,
        FLAGS.image_height,
        resize=False if (FLAGS.model == 'logo') else True,
        is_training=False,
        data_dir=FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data,
        per_core_batch_size=per_core_batch_size)

    if FLAGS.evalonly:
      validation_epochs = [420]
      logging.info("Only running a single validation epoch")
    else:
      validation_epochs = [3, 10, 30, 60, 90]
      logging.info("Validation will be run on epochs %s",
                   str(validation_epochs))

    eval_callback = eval_utils.TensorBoardWithValidation(
        log_dir=FLAGS.log,
        validation_imagenet_input=imagenet_validation,
        validation_steps=val_steps,
        validation_epochs=validation_epochs,
        write_images=True,
        write_graph=True,
        plot_wrong=FLAGS.plot_wrong,
        plot_cm=FLAGS.plot_cm,
        plot_pr=FLAGS.plot_pr,
        classes=classes,
        complete_eval=FLAGS.complete_eval)

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            FLAGS.log + "/weights.{epoch:02d}-{sparse_categorical_accuracy:.2f}.hdf5",
            monitor='sparse_categorical_accuracy',
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            mode='auto'),
        LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            train_steps_per_epoch=train_steps_per_epoch,
            base_learning_rate=FLAGS.lr),
        eval_callback,
    ]

    if FLAGS.tpu:
      model_in = imagenet_train.input_fn
    else:
      model_in = imagenet_train.input_fn()

    # Remember the freshly initialized weights so we can tell below whether
    # loading a checkpoint actually populated each layer.
    preloaded_weights = []
    for layer in model.layers:
      preloaded_weights.append(layer.get_weights())

    if FLAGS.weights:
      weights_file = os.path.join(FLAGS.weights)
      logging.info('Loading trained weights from %s', weights_file)
      model.load_weights(weights_file, by_name=FLAGS.weights_by_name)
      if FLAGS.weights2:
        weights2_file = os.path.join(FLAGS.weights2)
        logging.info('Loading secondary trained weights from %s', weights2_file)
        model.load_weights(weights2_file, by_name=FLAGS.weights_by_name)
    else:
      if FLAGS.weights2:
        logging.debug("Ignoring --weights2 flag as no --weights")
      weights_file = os.path.join(DEFAULT_WEIGHTS_H5)

    # Warn about layers whose weights are unchanged from their initialization,
    # which usually means they were not populated by the loaded checkpoint.
    for layer, pre in zip(model.layers, preloaded_weights):
      weights = layer.get_weights()
      populated = True
      if weights:
        for weight, pr in zip(weights, pre):
          if np.array_equal(weight, pr):
            populated = False
      if not populated:
        logging.warning('Layer %s not populated with weights!', layer.name)

    if FLAGS.evalonly:
      eval_callback.set_model(model)
      eval_callback.on_epoch_end(420)
    else:
      model.fit(model_in,
                epochs=FLAGS.epochs,
                initial_epoch=FLAGS.initial_epoch,
                class_weight=class_weights if FLAGS.class_weights else None,
                steps_per_epoch=train_steps_per_epoch,
                callbacks=callbacks)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(main)
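
# Minimal inference sketch (an assumption-laden example, not part of this
# script): it reuses the same `models` module and a checkpoint written by the
# ModelCheckpoint callback above; the file name, input array and num_classes
# are placeholders.
#
#   model = models.ResNet50(width=320, height=240, num_classes=num_classes)
#   model.load_weights('/tmp/netcraft/weights.17-0.83.hdf5')
#   probs = model.predict(images)  # images: float32 array, shape [N, 240, 320, 3]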