summaryrefslogtreecommitdiff
path: root/resnet50.py
blob: e063bfce4a57c461fba2555b9c3b77f899356612 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
#!/usr/bin/python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ResNet-50 implemented with Keras running on Cloud TPUs.

This file shows how you can run ResNet-50 on a Cloud TPU using the TensorFlow
Keras support. This is configured for ImageNet (e.g. 1000 classes), but you can
easily adapt to your own datasets by changing the code appropriately.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf

import eval_utils
import imagenet_input
import models
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import gradient_descent, adam

# h5py is an optional dependency: probe for it once at import time and record
# the outcome in HAS_H5PY rather than failing the whole script when it is
# missing.
try:
  import h5py as _  # pylint: disable=g-import-not-at-top
  HAS_H5PY = True
except ImportError:
  logging.warning('`h5py` is not installed. Please consider installing it '
                  'to save weights for long-running training.')
  HAS_H5PY = False


# Defaults for the input image size and the number of training epochs; these
# are the default values of the --image_width, --image_height and --epochs
# flags defined below.

DEF_IMAGE_WIDTH  = 320
DEF_IMAGE_HEIGHT = 240
DEF_EPOCHS = 90  # Standard imagenet training regime.

# Training hyperparameters.
# With --tpu the global batch size is NUM_CORES * PER_CORE_BATCH_SIZE;
# without it CPU_BATCH_SIZE is used.
NUM_CORES = 8
PER_CORE_BATCH_SIZE = 4
CPU_BATCH_SIZE = 1
BASE_LEARNING_RATE = 1e-3  # Default value of the --lr flag.
# Learning rate schedule consumed by learning_rate_schedule().
# The first tuple describes the warm-up: its multiplier is reached at the given
# epoch after a linear ramp from 0. Every later tuple gives the multiplier that
# applies from its start epoch onwards. Entries must be ordered by epoch.
LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]

# Default file name for saved weights.
DEFAULT_WEIGHTS_H5 = 'resnet50_weights.h5'
# Default values of the --log and --bucket flags respectively.
DEFAULT_LOG_DIR = '/tmp/netcraft'
DEFAULT_BUCKET = 'gs://netcraft/'

# Command-line flags. Names, types and defaults are part of the script's
# interface; only the help texts are descriptive.
flags.DEFINE_float('lr', BASE_LEARNING_RATE, 'Defines the step size when training')
flags.DEFINE_integer('epochs', DEF_EPOCHS, 'Number of epochs until which to train')
flags.DEFINE_integer('split_epochs', 1, 'Split epochs into smaller bits, helps save weights')
flags.DEFINE_integer('initial_epoch', 0, 'Epoch from which to start, useful when resuming training')
flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, 'Width of the input images')
flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, 'Height of the input images')
flags.DEFINE_string('weights', None, 'Use saved weights')
flags.DEFINE_string('weights2', None, 'Use another saved weights')
flags.DEFINE_string('bucket', DEFAULT_BUCKET, 'Bucket to use')
flags.DEFINE_string('tpu', None, 'Name of the TPU to use.')
flags.DEFINE_string('data', None, 'Path to training and testing data.')
# The help text lists every value accepted by the model selection in main().
flags.DEFINE_string('model', "resnet",
                    'Which model to build (resnet, combined, combined_trainable, logo, '
                    'logo_extended, logo_new or logo_extended_trainable)')
flags.DEFINE_string(
    'log', DEFAULT_LOG_DIR,
    ('The directory where the model weights and training/evaluation summaries '
     'are stored. If not specified, save to /tmp/netcraft.'))
flags.DEFINE_bool(
    'complete_eval', True,
    'Eval both top 1 and top 5 accuracy. Otherwise, only eval top 1 accuracy. '
    'Furthermore generate confusion matrixes and save softmax values in log_dir')
flags.DEFINE_bool('evalonly', False, 'Only run eval with given weights, do not train')
flags.DEFINE_bool('class_weights', False, 'Use class weights to deal with imbalanced dataset')
flags.DEFINE_integer('benign_multiplier', 1, 'Multiplier for weight of benign class')
flags.DEFINE_bool('plot_wrong', False, 'Plot false images in tensorboard, make eval slower')
flags.DEFINE_bool('plot_cm', True, 'Plot confusion matrix in tensorboard')
flags.DEFINE_bool('plot_pr', True, 'Plot precision recall in tensorboard')
flags.DEFINE_bool('weights_by_name', False, 'Load weights by name, this allows loading weights with an incompatible fully '+
                  'connect layer i.e. a different number of targets. The FC layer is randomly initiated and needs to be trained.')

FLAGS = flags.FLAGS

def learning_rate_schedule(current_epoch, current_batch, train_steps_per_epoch,
                           base_learning_rate, schedule=None):
  """Handles gradual warmup and step-wise LR decay.

  The fractional epoch is `current_epoch + current_batch /
  train_steps_per_epoch`. The first `(multiplier, epoch)` entry of the
  schedule describes the warm-up: the learning rate starts at 0 and grows
  linearly per step until it reaches `base_learning_rate * multiplier` at that
  epoch. Afterwards the multiplier of the last entry whose start epoch has
  been reached is applied. With the default `LR_SCHEDULE` this means full
  rate from epoch 5 and a 10x reduction at epochs 30, 60 and 80.

  This function never sets the learning rate to 0 after the last entry; the
  final multiplier stays in effect for as long as training continues.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    train_steps_per_epoch: number of batches per epoch; must be non-zero.
    base_learning_rate: float, learning rate that the multipliers scale.
    schedule: optional non-empty sequence of `(multiplier, start_epoch)`
      tuples ordered by increasing `start_epoch`. Defaults to the module-level
      `LR_SCHEDULE`.

  Returns:
    Adjusted learning rate.
  """
  if schedule is None:
    schedule = LR_SCHEDULE
  epoch = current_epoch + float(current_batch) / train_steps_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = schedule[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return base_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
  # Past the warm-up the first entry always applies; binding it up front makes
  # it obvious that `learning_rate` is defined whatever the loop does.
  learning_rate = base_learning_rate * warmup_lr_multiplier
  for mult, start_epoch in schedule:
    if epoch < start_epoch:
      # Entries are ordered, so no later entry can apply either.
      break
    learning_rate = base_learning_rate * mult
  return learning_rate


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only support Keras optimizers, not TF optimizers.

  The epoch index handed to `on_epoch_begin` is used as the current epoch, so
  a run that starts at a non-zero epoch continues the schedule from that
  point instead of starting over at epoch 0.

  Args:
      schedule: a function called as `schedule(epoch, batch,
          train_steps_per_epoch, base_learning_rate)` where epoch and batch
          are integers indexed from 0; it returns the new learning rate as a
          float.
      train_steps_per_epoch: number of batches per epoch, forwarded to
          `schedule`.
      base_learning_rate: base learning rate, forwarded to `schedule`.
  """

  def __init__(self, schedule, train_steps_per_epoch, base_learning_rate):
    super(LearningRateBatchScheduler, self).__init__()
    self.base_lr = base_learning_rate
    self.schedule = schedule
    self.train_steps_per_epoch = train_steps_per_epoch
    # Index of the epoch currently running; -1 until the first epoch begins.
    self.epochs = -1
    # Last value written to the optimizer; -1 forces a write on the first batch.
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    """Records the epoch index and checks the optimizer exposes `lr`.

    Raises:
      ValueError: if the model's optimizer has no `lr` attribute.
    """
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    # Take the index supplied by the training loop instead of counting calls,
    # so that a non-zero starting epoch is honoured.
    self.epochs = epoch

  def on_batch_begin(self, batch, logs=None):
    """Computes the learning rate for this batch and applies it if it changed.

    Raises:
      ValueError: if the schedule function does not return a float.
    """
    lr = self.schedule(self.epochs, batch, self.train_steps_per_epoch, self.base_lr)
    if not isinstance(lr, (float, np.floating)):
      raise ValueError('The output of the "schedule" function should be float.')
    # Only touch the optimizer when the value actually changes.
    if lr != self.prev_lr:
      K.set_value(self.model.optimizer.lr, lr)
      self.prev_lr = lr
      logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
                    'learning rate to %s.', self.epochs, batch, lr)


def main(argv):
  """Builds the selected model, optionally loads weights, then trains or evals.

  With --data set, dataset metadata is read from `dinfo.npz` in that
  directory and the model is trained (or, with --evalonly, evaluated once) on
  the real input pipeline. Without --data the model is fitted and evaluated on
  a single batch of random synthetic data.

  Args:
    argv: command-line arguments left over after flag parsing; unused.

  Returns:
    None on success, or an error message string when --model is not one of
    the supported model names.
  """
  del argv  # Unused.

  # Decide once whether real data is used so that every later branch agrees,
  # including for an empty --data string.
  use_real_data = bool(FLAGS.data)

  if use_real_data:
    # NOTE: allow_pickle=True unpickles objects stored in the archive, which
    # can execute arbitrary code. Only point --data at trusted directories.
    dinfo = np.load(os.path.join(FLAGS.data, 'dinfo.npz'), allow_pickle=True)
    classes = dinfo['classes']
    num_classes = len(classes)
    train_cnt = dinfo['train_cnt']  # Number of training images.
    val_cnt = dinfo['val_cnt']  # Number of validation images.
    class_weights = dinfo['class_weights'].tolist()

    if FLAGS.class_weights and FLAGS.benign_multiplier != 1:
      # Look the class up by position. Testing the index itself for
      # truthiness would wrongly treat a benign class at index 0 as missing.
      benign_matches = np.flatnonzero(classes == 'benign')
      if benign_matches.size:
        benign_class = int(benign_matches[0])
        class_weights[benign_class] *= FLAGS.benign_multiplier
      else:
        logging.warning("Could not find benign class. Ignoring benign multiplier.")
  else:
    # Synthetic-data defaults. These are integers because num_classes is used
    # as a layer size and as the upper bound for np.random.randint.
    train_cnt = 1000000
    val_cnt = 100000
    num_classes = 1000

  if FLAGS.tpu:
    batch_size = NUM_CORES * PER_CORE_BATCH_SIZE
  else:
    batch_size = CPU_BATCH_SIZE

  # --split_epochs shortens each Keras epoch to a fraction of the data set.
  train_steps_per_epoch = int(train_cnt / (batch_size * FLAGS.split_epochs))
  val_steps = int(val_cnt // batch_size)

  logging.info("Using %d training images and %d for validation", train_cnt, val_cnt)

  if FLAGS.model == 'resnet':
    logging.info('Building Keras ResNet-50 model')
    model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes)
  elif FLAGS.model == 'combined':
    logging.info('Building Keras ResNet-50 + LOGO model')
    model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes, resnet_trainable=False)
  elif FLAGS.model == 'combined_trainable':
    logging.info('Building Keras ResNet-50 + LOGO model')
    model = models.get_logores_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=num_classes, resnet_trainable=True)
  elif FLAGS.model == 'logo':
    logging.info('Building LogoNet model')
    model = models.get_logo_model(width=None, height=None, num_classes=num_classes, base_trainable=True)
  elif FLAGS.model == 'logo_extended':
    logging.info('Building LogoNet model')
    model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=False, num_classes=num_classes)
  elif FLAGS.model == 'logo_new':
    logging.info('Building LogoNet model')
    model = models.get_logo_model_new(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=False, num_classes=num_classes)
  elif FLAGS.model == 'logo_extended_trainable':
    logging.info('Building LogoNet model')
    model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, base_trainable=True, num_classes=num_classes)
  else:
    return ('Only valid models are resnet, combined, combined_trainable, logo, '
            'logo_extended, logo_new and logo_extended_trainable')

  if FLAGS.tpu:
    logging.info('Converting from CPU to TPU model.')
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

  logging.info('Compiling model.')
  model.compile(
      optimizer=adam.Adam(learning_rate=FLAGS.lr),
      loss='sparse_categorical_crossentropy',
      metrics=['sparse_categorical_accuracy'])

  if not use_real_data:
    training_images = np.random.randn(
        batch_size, FLAGS.image_height, FLAGS.image_width, 3).astype(np.float32)
    training_labels = np.random.randint(num_classes, size=batch_size,
                                        dtype=np.int32)
    logging.info('Training model using synthetica data, use --data flag to provided real data.')
    model.fit(
        training_images,
        training_labels,
        epochs=FLAGS.epochs,
        initial_epoch=FLAGS.initial_epoch,
        batch_size=batch_size)
    logging.info('Evaluating the model on synthetic data.')
    model.evaluate(training_images, training_labels, verbose=0)
    return None

  # On TPU the data directory is prefixed with the bucket given by --bucket.
  data_dir = FLAGS.bucket + FLAGS.data if FLAGS.tpu else FLAGS.data
  # The plain 'logo' model is built without a fixed input size (see above),
  # so its inputs are not resized.
  resize = FLAGS.model != 'logo'

  per_core_batch_size = PER_CORE_BATCH_SIZE if FLAGS.tpu else CPU_BATCH_SIZE
  imagenet_train = imagenet_input.ImageNetInput(
      width=FLAGS.image_width,
      height=FLAGS.image_height,
      resize=resize,
      is_training=True,
      data_dir=data_dir,
      per_core_batch_size=per_core_batch_size)
  logging.info('Training model using real data in directory "%s".',
               FLAGS.data)
  # If evaluating complete_eval, we feed the inputs from a Python generator,
  # so we need to build a single batch for all of the cores, which will be
  # split on TPU.
  per_core_batch_size = (
      batch_size if (FLAGS.complete_eval or not FLAGS.tpu) else PER_CORE_BATCH_SIZE)
  imagenet_validation = imagenet_input.ImageNetInput(
      FLAGS.image_width, FLAGS.image_height,
      resize=resize,
      is_training=False,
      data_dir=data_dir,
      per_core_batch_size=per_core_batch_size)

  if FLAGS.evalonly:
    # 420 is the placeholder epoch number used for the single eval-only pass
    # triggered at the end of this function.
    validation_epochs = [420]
    logging.info("Only running a single validation epoch")
  else:
    validation_epochs = [3, 10, 30, 60, 90]
    logging.info("Validation will be run on epochs %s", str(validation_epochs))

  eval_callback = eval_utils.TensorBoardWithValidation(
      log_dir=FLAGS.log,
      validation_imagenet_input=imagenet_validation,
      validation_steps=val_steps,
      validation_epochs=validation_epochs,
      write_images=True,
      write_graph=True,
      plot_wrong=FLAGS.plot_wrong,
      plot_cm=FLAGS.plot_cm,
      plot_pr=FLAGS.plot_pr,
      classes=classes,
      complete_eval=FLAGS.complete_eval)

  callbacks = [
      tf.keras.callbacks.ModelCheckpoint(
          FLAGS.log+"/weights.{epoch:02d}-{sparse_categorical_accuracy:.2f}.hdf5",
          monitor='sparse_categorical_accuracy', verbose=1,
          save_best_only=True, save_weights_only=True, mode='auto'),
      LearningRateBatchScheduler(
          schedule=learning_rate_schedule,
          train_steps_per_epoch=train_steps_per_epoch,
          base_learning_rate=FLAGS.lr),
      eval_callback,
  ]

  # The TPU model is given the input function itself; otherwise the function
  # is called here and its result is passed to fit().
  if FLAGS.tpu:
    model_in = imagenet_train.input_fn
  else:
    model_in = imagenet_train.input_fn()

  # Snapshot the freshly initialised weights so that we can tell afterwards
  # which layers were actually changed by load_weights().
  preloaded_weights = [layer.get_weights() for layer in model.layers]

  if FLAGS.weights:
    weights_file = FLAGS.weights
    logging.info('Loading trained weights from %s', weights_file)
    model.load_weights(weights_file, by_name=FLAGS.weights_by_name)
    if FLAGS.weights2:
      weights2_file = FLAGS.weights2
      logging.info('Loading secondary trained weights from %s', weights2_file)
      model.load_weights(weights2_file, by_name=FLAGS.weights_by_name)
  elif FLAGS.weights2:
    logging.debug("Ignoring --weights2 flag as no --weights")

  # Check if we loaded weights: a layer is reported when any of its weight
  # arrays is still identical to its initial value.
  for layer, initial in zip(model.layers, preloaded_weights):
    current = layer.get_weights()
    unchanged = any(
        np.array_equal(now, before) for now, before in zip(current, initial))
    if unchanged:
      logging.warning('Layer %s not populated with weights!', layer.name)

  if FLAGS.evalonly:
    eval_callback.set_model(model)
    eval_callback.on_epoch_end(420)
  else:
    model.fit(model_in,
              epochs=FLAGS.epochs,
              initial_epoch=FLAGS.initial_epoch,
              class_weight=class_weights if FLAGS.class_weights else None,
              steps_per_epoch=train_steps_per_epoch,
              callbacks=callbacks)
  return None

# Script entry point: raise TensorFlow's log verbosity to INFO, then hand
# control to absl, which parses the flags defined above and calls main().
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(main)