summaryrefslogtreecommitdiff
path: root/resnet50.py
diff options
context:
space:
mode:
Diffstat (limited to 'resnet50.py')
-rwxr-xr-xresnet50.py348
1 files changed, 348 insertions, 0 deletions
diff --git a/resnet50.py b/resnet50.py
new file mode 100755
index 0000000..e063bfc
--- /dev/null
+++ b/resnet50.py
@@ -0,0 +1,348 @@
+#!/usr/bin/python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""ResNet-50 implemented with Keras running on Cloud TPUs.
+
+This file shows how you can run ResNet-50 on a Cloud TPU using the TensorFlow
+Keras support. This is configured for ImageNet (e.g. 1000 classes), but you can
+easily adapt to your own datasets by changing the code appropriately.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+import tensorflow as tf
+
+import eval_utils
+import imagenet_input
+import models
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.optimizer_v2 import gradient_descent, adam
+
+try:
+ import h5py as _ # pylint: disable=g-import-not-at-top
+ HAS_H5PY = True
+except ImportError:
+ logging.warning('`h5py` is not installed. Please consider installing it '
+ 'to save weights for long-running training.')
+ HAS_H5PY = False
+
+
+# Imagenet training and test data sets.
+
+DEF_IMAGE_WIDTH = 320
+DEF_IMAGE_HEIGHT = 240
+DEF_EPOCHS = 90 # Standard imagenet training regime.
+
+# Training hyperparameters.
+NUM_CORES = 8
+PER_CORE_BATCH_SIZE = 4
+CPU_BATCH_SIZE = 1
+BASE_LEARNING_RATE = 1e-3
+# Learning rate schedule
+LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+ (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+]
+
+DEFAULT_WEIGHTS_H5 = 'resnet50_weights.h5'
+DEFAULT_LOG_DIR = '/tmp/netcraft'
+DEFAULT_BUCKET = 'gs://netcraft/'
+
+flags.DEFINE_float('lr', BASE_LEARNING_RATE, 'Defines the step size when training')
+flags.DEFINE_integer('epochs', DEF_EPOCHS, 'Number of epochs until which to train')
+flags.DEFINE_integer('split_epochs', 1, 'Split epochs into smaller bits, helps save weights')
+flags.DEFINE_integer('initial_epoch', 0, 'Epoch from which to start, useful when resuming training')
+flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, '')
+flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, '')
+flags.DEFINE_string('weights', None, 'Use saved weights')
+flags.DEFINE_string('weights2', None, 'Use another saved weights')
+flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'Bucket to use')
+flags.DEFINE_string('tpu', None, 'Name of the TPU to use.')
+flags.DEFINE_string('data', None, 'Path to training and testing data.')
+flags.DEFINE_string('model', "resnet", 'Which logo to use (resnet, or logo)')
+flags.DEFINE_string(
+ 'log', DEFAULT_LOG_DIR,
+ ('The directory where the model weights and training/evaluation summaries '
+ 'are stored. If not specified, save to /tmp/netcraft.'))
+flags.DEFINE_bool(
+ 'complete_eval', True,
+ 'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. '
+ 'Furthemore generate confusion matrixes and save softmax values in log_dir')
+flags.DEFINE_bool('evalonly', False, 'Only run eval with given weights, do not train')
+flags.DEFINE_bool('class_weights', False, 'Use class weights to deal with imbalanced dataset')
+flags.DEFINE_integer('benign_multiplier', 1, 'Multiplier for weigh tof benign class')
+flags.DEFINE_bool('plot_wrong', False, 'Plot false images in tensorboard, make eval slower')
+flags.DEFINE_bool('plot_cm', True, 'Plot confusion matrix in tensorboard')
+flags.DEFINE_bool('plot_pr', True, 'Plot precision recall in tensorboard')
+flags.DEFINE_bool('weights_by_name', False, 'Load weights by name, this allows loading weights with an incompatible fully '+
+ 'connect layer i.e. a different number of targets. The FC layer is randomly initiated and needs to be trained.')
+
+FLAGS = flags.FLAGS
+
+def learning_rate_schedule(current_epoch, current_batch, train_steps_per_epoch, base_learning_rate):
+ """Handles linear scaling rule, gradual warmup, and LR decay.
+
+ The learning rate starts at 0, then it increases linearly per step.
+ After 5 epochs we reach the base learning rate (scaled to account
+ for batch size).
+ After 30, 60 and 80 epochs the learning rate is divided by 10.
+ After 90 epochs training stops and the LR is set to 0. This ensures
+ that we train for exactly 90 epochs for reproducibility.
+
+ Args:
+ current_epoch: integer, current epoch indexed from 0.
+ current_batch: integer, current batch in the current epoch, indexed from 0.
+
+ Returns:
+ Adjusted learning rate.
+ """
+ epoch = current_epoch + float(current_batch) / train_steps_per_epoch
+ warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
+ if epoch < warmup_end_epoch:
+ # Learning rate increases linearly per step.
+ return base_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
+ for mult, start_epoch in LR_SCHEDULE:
+ if epoch >= start_epoch:
+ learning_rate = base_learning_rate * mult
+ else:
+ break
+ return learning_rate
+
+
+class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
+ """Callback to update learning rate on every batch (not epoch boundaries).
+
+ N.B. Only support Keras optimizers, not TF optimizers.
+
+ Args:
+ schedule: a function that takes an epoch index and a batch index as input
+ (both integer, indexed from 0) and returns a new learning rate as
+ output (float).
+ """
+
+ def __init__(self, schedule, train_steps_per_epoch, base_learning_rate):
+ super(LearningRateBatchScheduler, self).__init__()
+ self.base_lr = base_learning_rate
+ self.schedule = schedule
+ self.train_steps_per_epoch = train_steps_per_epoch
+ self.epochs = -1
+ self.prev_lr = -1
+
+ def on_epoch_begin(self, epoch, logs=None):
+ if not hasattr(self.model.optimizer, 'lr'):
+ raise ValueError('Optimizer must have a "lr" attribute.')
+ self.epochs += 1
+
+ def on_batch_begin(self, batch, logs=None):
+ lr = self.schedule(self.epochs, batch, self.train_steps_per_epoch, self.base_lr)
+ if not isinstance(lr, (float, np.float32, np.float64)):
+ raise ValueError('The output of the "schedule" function should be float.')
+ if lr != self.prev_lr:
+ K.set_value(self.model.optimizer.lr, lr)
+ self.prev_lr = lr
+ logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
+ 'learning rate to %s.', self.epochs, batch, lr)
+
+
+def main(argv):
+ if FLAGS.data:
+ dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz'), allow_pickle=True)
+ classes = dinfo['classes']
+ num_classes = len(classes)
+ train_cnt = dinfo['train_cnt'] # 1141 # 50273 # Approximate number of images.
+ val_cnt = dinfo['val_cnt'] # 488 # 12560 # Number of images.
+ class_weights = dinfo['class_weights'].tolist()
+ #class_weights = class_weights[()] # Unpack 0d np.array
+
+ if FLAGS.class_weights and FLAGS.benign_multiplier != 1:
+ benign_class = np.squeeze(np.where(classes=='benign'))
+ if benign_class:
+ benign_class = np.asscalar(benign_class)
+ class_weights[benign_class] *= FLAGS.benign_multiplier
+ else:
+ logging.warning("Could not find benign class. Ignoring benign multiplier.")
+ else:
+ train_cnt = 10e5
+ val_cnt = 10e4
+ num_classes = 10e2
+
+ if FLAGS.tpu:
+ batch_size = NUM_CORES * PER_CORE_BATCH_SIZE
+ else:
+ batch_size = CPU_BATCH_SIZE
+
+ train_steps_per_epoch = int(train_cnt / (batch_size * FLAGS.split_epochs))
+ val_steps = int(val_cnt // batch_size )
+
+ logging.info("Using %d training images and %d for validation", train_cnt, val_cnt)
+
+ if FLAGS.model == 'resnet':
+ logging.info('Building Keras ResNet-50 model')
+ model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes)
+ elif FLAGS.model == 'combined':
+ logging.info('Building Keras ResNet-50 + LOGO model')
+ model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes, resnet_trainable=False)
+ elif FLAGS.model == 'combined_trainable':
+ logging.info('Building Keras ResNet-50 + LOGO model')
+ model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes, resnet_trainable=True)
+ elif FLAGS.model == 'logo':
+ logging.info('Building LogoNet model')
+ model = models.get_logo_model(width=None, height=None, num_classes=num_classes, base_trainable=True)
+ elif FLAGS.model == 'logo_extended':
+ logging.info('Building LogoNet model')
+ model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=False, num_classes=num_classes)
+ elif FLAGS.model == 'logo_new':
+ logging.info('Building LogoNet model')
+ model = models.get_logo_model_new(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=False, num_classes=num_classes)
+ elif FLAGS.model == 'logo_extended_trainable':
+ logging.info('Building LogoNet model')
+ model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=True, num_classes=num_classes)
+ else:
+ return 'Only valid models are resnet and logo'
+
+ if FLAGS.tpu:
+ logging.info('Converting from CPU to TPU model.')
+ resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
+ strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
+ model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
+
+ logging.info('Compiling model.')
+ model.compile(
+ optimizer=adam.Adam(learning_rate=FLAGS.lr),
+ loss='sparse_categorical_crossentropy',
+ metrics=['sparse_categorical_accuracy'])
+
+ if FLAGS.data is None:
+ training_images = np.random.randn(
+ batch_size, FLAGS.image_height, FLAGS.image_width, 3).astype(np.float32)
+ training_labels = np.random.randint(num_classes, size=batch_size,
+ dtype=np.int32)
+ logging.info('Training model using synthetica data, use --data flag to provided real data.')
+ model.fit(
+ training_images,
+ training_labels,
+ epochs=FLAGS.epochs,
+ initial_epoch=FLAGS.initial_epoch,
+ batch_size=batch_size)
+ logging.info('Evaluating the model on synthetic data.')
+ model.evaluate(training_images, training_labels, verbose=0)
+ else:
+ per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE
+ imagenet_train = imagenet_input.ImageNetInput(
+ width=FLAGS.image_width,
+ height=FLAGS.image_height,
+ resize=False if (FLAGS.model == 'logo') else True,
+ is_training=True,
+ data_dir=FLAGS.bucket+FLAGS.data if FLAGS.tpu else FLAGS.data,
+ per_core_batch_size=per_core_batch_size)
+ logging.info('Training model using real data in directory "%s".',
+ FLAGS.data)
+ # If evaluating complete_eval, we feed the inputs from a Python generator,
+ # so we need to build a single batch for all of the cores, which will be
+ # split on TPU.
+ per_core_batch_size = (
+ batch_size if (FLAGS.complete_eval or not FLAGS.tpu) else PER_CORE_BATCH_SIZE)
+ imagenet_validation = imagenet_input.ImageNetInput(
+ FLAGS.image_width, FLAGS.image_height,
+ resize=False if (FLAGS.model == 'logo') else True,
+ is_training=False,
+ data_dir=FLAGS.bucket+FLAGS.data if FLAGS.tpu else FLAGS.data,
+ per_core_batch_size=per_core_batch_size)
+
+ if FLAGS.evalonly:
+ validation_epochs= [420]
+ logging.info("Only running a single validation epoch")
+ else:
+ validation_epochs=[ 3, 10, 30, 60, 90]
+ logging.info("Validation will be run on epochs %s", str(validation_epochs))
+
+ eval_callback = eval_utils.TensorBoardWithValidation(
+ log_dir=FLAGS.log,
+ validation_imagenet_input=imagenet_validation,
+ validation_steps=val_steps,
+ validation_epochs=validation_epochs,
+ write_images=True,
+ write_graph=True,
+ plot_wrong=FLAGS.plot_wrong,
+ plot_cm=FLAGS.plot_cm,
+ plot_pr=FLAGS.plot_pr,
+ classes=classes,
+ complete_eval=FLAGS.complete_eval)
+
+ callbacks = [
+ tf.keras.callbacks.ModelCheckpoint(FLAGS.log+"/weights.{epoch:02d}-{sparse_categorical_accuracy:.2f}.hdf5",
+ monitor='sparse_categorical_accuracy', verbose=1,
+ save_best_only=True, save_weights_only=True, mode='auto'),
+ LearningRateBatchScheduler(schedule=learning_rate_schedule, train_steps_per_epoch=train_steps_per_epoch, base_learning_rate=FLAGS.lr),
+ eval_callback
+ ]
+
+ if FLAGS.tpu:
+ model_in = imagenet_train.input_fn
+ else:
+ model_in = imagenet_train.input_fn()
+
+ preloaded_weights = []
+ for layer in model.layers:
+ preloaded_weights.append(layer.get_weights())
+
+ if FLAGS.weights:
+ weights_file = os.path.join(FLAGS.weights)
+ logging.info('Loading trained weights from %s', weights_file)
+ model.load_weights(weights_file, by_name=FLAGS.weights_by_name)
+ if FLAGS.weights2:
+ weights2_file = os.path.join(FLAGS.weights2)
+ logging.info('Loading secondary trained weights from %s', weights2_file)
+ model.load_weights(weights2_file, by_name=FLAGS.weights_by_name)
+ else:
+ if FLAGS.weights2:
+ logging.debug("Ignoring --weights2 flag as no --weights")
+ weights_file = os.path.join(DEFAULT_WEIGHTS_H5)
+
+ # Check if we loaded weights
+ for layer, pre in zip(model.layers, preloaded_weights):
+ weights = layer.get_weights()
+
+ populated=True
+ if weights:
+ for weight, pr in zip(weights, pre):
+ if np.array_equal(weight, pr):
+ populated=False
+
+ if not populated:
+ logging.warning('Layer %s not populated with weights!', layer.name)
+
+ if FLAGS.evalonly:
+ eval_callback.set_model(model)
+ eval_callback.on_epoch_end(420)
+ else:
+ model.fit(model_in,
+ epochs=FLAGS.epochs,
+ initial_epoch=FLAGS.initial_epoch,
+ class_weight = class_weights if FLAGS.class_weights else None,
+ steps_per_epoch=train_steps_per_epoch,
+ callbacks=callbacks)
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ app.run(main)