path: root/eval_utils.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation utils for `KerasTPUmodel`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import scikitplot as skplt
from sklearn.metrics import confusion_matrix
from tqdm import trange

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.platform import tf_logging as logging

def save_softmax(log_dir, epoch, labels, predictions):
    """Saves the epoch's labels and softmax predictions to an .npz archive."""
    location = os.path.join(log_dir, 'softmax' + str(epoch) + '.npz')
    np.savez(location, labels=labels, predictions=predictions)
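    # The saved arrays can be read back later for offline analysis, e.g.:
    #   data = np.load(location)
    #   labels, predictions = data['labels'], data['predictions']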

def draw_graphs(self, log_dir, classes, y_true, y_probas, epoch):
    """Plots a confusion matrix and/or precision-recall curves to TensorBoard.

    `self` is the TensorBoardWithValidation callback instance; only its
    `_plot_cm` and `_plot_pr` flags are consulted here.
    """
    y_pred = np.argmax(y_probas, axis=1)
    if self._plot_cm:
        skplt.metrics.plot_confusion_matrix(y_true, y_pred, normalize=True)
        plot_to_tensorboard(log_dir, epoch, "model_projections", "confusion_matrix")
    if self._plot_pr:
        skplt.metrics.plot_precision_recall(y_true, y_probas)
        plot_to_tensorboard(log_dir, epoch, "model_projections", "pr_curve")

def plot_to_tensorboard(log_dir, epoch, model_projection, family_name):
    """Serializes the current matplotlib figure and logs it as an image summary."""
    buf = io.BytesIO()
    plt.rcParams.update({'font.size': 5})
    plt.savefig(buf, dpi=250, format='png')
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=3)
    image = tf.expand_dims(image, 0)

    summary_op = tf.summary.image(model_projection, image, max_outputs=1, family=family_name)
    writer = tf.summary.FileWriter(log_dir)
    writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
    writer.close()

def draw_c_matrix(log_dir, c_matrix, classes, epoch, normalize=False):
    """Renders a confusion matrix with matplotlib and logs it to TensorBoard."""
    if normalize:
        # Row-normalize so each cell shows the fraction of its true class.
        c_matrix = c_matrix.astype('float') / c_matrix.sum(axis=1)[:, np.newaxis]
    plt.figure()
    plt.imshow(c_matrix, cmap=plt.cm.Blues)
    plt.xlabel('Predicted')
    plt.ylabel('True Label')
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = c_matrix.max() / 2.
    for i, j in itertools.product(range(c_matrix.shape[0]), range(c_matrix.shape[1])):
        plt.text(j, i, format(c_matrix[i, j], fmt),
             horizontalalignment="center",
             color="white" if c_matrix[i, j] > thresh else "black")

    buf = io.BytesIO()
    plt.savefig(buf, dpi=500, format='png')
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    image = tf.expand_dims(image, 0)

    summary_op = tf.summary.image("model_projections", image, max_outputs=1, family='family_name')
    writer = tf.summary.FileWriter(log_dir)
    writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
    writer.close()

def multi_top_k_accuracy(self, log_dir, model, evaluation_generator, eval_steps, classes, epoch, ks=(1, 5)):
  """Calculates top k accuracy for the given `k` values.

  Args:
    self: the `TensorBoardWithValidation` callback instance (its plotting flags
          control which evaluation figures are produced).
    log_dir: directory where TensorBoard summaries and softmax dumps are written.
    model: `KerasTPUModel` to evaluate.
    evaluation_generator: a Python generator yielding (features, labels) pairs
                          for evaluation.
    eval_steps: int, number of evaluation steps.
    classes: array of class names, indexed by label id.
    epoch: int, current epoch, used to tag summaries and output files.
    ks: a tuple of int, position values to calculate top k accuracy.

  Returns:
    A dictionary containing top k accuracy for the given `k` values.
  """
  def _count_matched(classes, predictions, labels, ks):
    """Count number of pairs with label in any of top k predictions."""
    top_k_matched = dict.fromkeys(ks, 0)
    for prediction, label in zip(predictions, labels):
      for k in ks:
        top_k_predictions = np.argpartition(prediction, -k)[-k:]
        if label in top_k_predictions:
          top_k_matched[k] += 1

    return top_k_matched

  total = 0
  top_k_matched = dict.fromkeys(ks, 0)
  c_matrix = np.zeros((len(classes), len(classes)))
  all_labels = np.zeros((0, 1))
  all_predictions = np.zeros((0, len(classes)))
  logging.info('There are %d validation steps', eval_steps)
  t = trange(eval_steps)
  for step in t:
    try:
      (features, labels) = next(evaluation_generator)
    except Exception as e:
      logging.debug(e)
      break
    predictions = model.predict_on_batch(features) # May be quicker
    # predictions = model.predict(features, batch_size=8)
    sorted_pred_args = np.flip(predictions.argsort(axis=1), axis=1)
    flat_predictions = sorted_pred_args[:, 0]

    # TODO: clean this function, it is a mess

    # Print some falsely predicted images
    if self._plot_wrong:  # Optionally throttle with e.g. `and not (step + 4) % 8`.
        # Squeeze labels into same dimension and type as predictions
        sq_labels = np.squeeze(labels.astype(int))
        # With a single image per batch, squeeze drops one dimension too many,
        # leaving a 0-d array; restore the batch dimension.
        if sq_labels.shape == ():
            sq_labels = np.expand_dims(sq_labels, axis=0)
        failed_indexes = np.where(np.not_equal(flat_predictions, sq_labels))[0]
        limiter = 0
        for idx in failed_indexes:
            if limiter > 90:
                break
            limiter += 1
            predicted_class_name = classes[flat_predictions[idx]]
            true_class_name = classes[sq_labels[idx]]
            proba_range = range(3)  # Show softmax for the top 3 classes
            top_cl = classes[sorted_pred_args[idx][proba_range]]
            probas = predictions[idx][sorted_pred_args[idx][proba_range]]
            if probas[0] > 0.9:  # Only log confidently wrong predictions.
                top_3 = '\n'.join(cl + ": " + proba for cl, proba in zip(top_cl, probas.astype(str)))
                print("Predicted:", flat_predictions[idx],
                      "True:", sq_labels[idx],
                      "Proba:", probas.astype(str))
                plt.clf()
                plt.imshow(features[idx].astype(int))
                plt.text(0, 0, top_3, size=9, va="bottom", bbox=dict(boxstyle="square", ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8),))
                plot_to_tensorboard(log_dir, epoch, "mislabled_images", "P_"+predicted_class_name+"_Tr_"+true_class_name)

    # Accumulate the running confusion matrix; passing labels=range(len(classes))
    # keeps its shape fixed even when a batch does not contain every class.
    c_matrix += confusion_matrix(labels, flat_predictions, labels=range(len(classes)))
    batch_top_k_matched = _count_matched(classes, predictions, labels, ks)
    all_labels = np.vstack((all_labels, labels))
    all_predictions = np.vstack((all_predictions, predictions))
    for k, matched in batch_top_k_matched.items():
      top_k_matched[k] += matched
    total += len(labels)

    t.set_description("Top 1: %f" % np.float_(top_k_matched[1] / float(total)))

  logging.info("Confusion matrix:")
  print(c_matrix)

  try:
      # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch, normalize=True)
      # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch, normalize=False)
      draw_graphs(self, log_dir, classes, all_labels, all_predictions, epoch)
  except Exception as e:  # pylint: disable=broad-except
      logging.warning('Failed to draw evaluation graphs: %s', e)

  save_softmax(log_dir, epoch, all_labels, all_predictions)
  metrics = dict([('top_{0}_accuracy'.format(k), np.float_(matched / float(total)))
               for k, matched in top_k_matched.items()])
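  # Example result for ks=(1, 5): {'top_1_accuracy': 0.72, 'top_5_accuracy': 0.91} (illustrative values).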
  print(metrics)
  return metrics


class TensorBoardWithValidation(callbacks.TensorBoard):
  """Extends the TensorBoard callback with validation.

  Validation is executed at the end of specified epochs, and the validation
  metrics are exported to TensorBoard for visualization.

  Args:
      log_dir: the path of the directory where to save the log
          files to be parsed by TensorBoard.
      validation_imagenet_input: ImageNetInput for validation.
      validation_steps: total number of steps to validate.
      validation_epochs: a list of integers, epochs to run validation.
      write_graph: boolean, as in the base TensorBoard callback.
      write_images: boolean, as in the base TensorBoard callback.
      plot_wrong: boolean, if true, log confidently mispredicted images.
      plot_cm: boolean, if true, log a confusion matrix plot.
      plot_pr: boolean, if true, log a precision-recall curve plot.
      classes: array of class names, indexed by label id.
      complete_eval: boolean, if true, evaluate top k accuracies using
          multi_top_k_accuracy(). Otherwise, use model.evaluate().
          N.B. enabling this significantly slows down evaluation because a
          Python generator is used for the evaluation input.
      top_ks: a tuple of int, position values to calculate top k accuracy.
          Only used when complete_eval is true.
  """

  def __init__(self,
               log_dir,
               validation_imagenet_input,
               validation_steps,
               validation_epochs,
               write_graph,
               write_images,
               plot_wrong,
               plot_cm,
               plot_pr,
               classes,
               complete_eval,
               top_ks=(1, 5)):
    super(TensorBoardWithValidation, self).__init__(log_dir)
    self._validation_imagenet_input = validation_imagenet_input
    self._validation_steps = validation_steps
    self._validation_epochs = validation_epochs
    self._write_graph = write_graph
    self._write_images = write_images
    self._plot_wrong = plot_wrong
    self._plot_cm = plot_cm
    self._plot_pr = plot_pr
    self._complete_eval = complete_eval
    self._top_ks = top_ks
    self._targets = classes
    self._log_dir = log_dir

  def on_epoch_end(self, epoch, logs=None):
    logs = logs if logs is not None else {}
    if epoch in self._validation_epochs:
      logging.info('\nValidate in epoch %s', epoch)
      if self._complete_eval:
        logging.info("Running complete eval")
        score = multi_top_k_accuracy(
            self,
            self._log_dir,
            self.model,
            self._validation_imagenet_input.evaluation_generator(
                K.get_session()),
            self._validation_steps,
            self._targets,
            epoch,
            ks=self._top_ks)
        for metric_name, metric_value in score.items():
          logs['val_' + metric_name] = metric_value
      else:
        # evaluate() is executed as a callback during training. In this case,
        # _numpy_to_infeed_manager_list is not empty, so save it for
        # recovery at the end of the evaluate call.
        # TODO(jingli): remove this monkey patch hack once the fix is included
        # in future TF release.
        original_numpy_to_infeed_manager_list = []
        if self.model._numpy_to_infeed_manager_list:
          original_numpy_to_infeed_manager_list = (
              self.model._numpy_to_infeed_manager_list)
          self.model._numpy_to_infeed_manager_list = []
        # Set _eval_function to None to enforce recompilation, so that the newly
        # created dataset from self._validation_imagenet_input.input_fn is used
        # in evaluation.
        # pylint: disable=bare-except
        # pylint: disable=protected-access
        try:
          self.model._eval_function = None
        except:
          pass

        try:
          # In TF 1.12, _eval_function does not exist; only test_function does.
          self.model.test_function = None
        except:
          pass

        scores = self.model.evaluate(self._validation_imagenet_input.input_fn,
                                     steps=self._validation_steps)
        self.model._numpy_to_infeed_manager_list = (
            original_numpy_to_infeed_manager_list)
        for metric_name, metric_value in zip(self.model.metrics_names, scores):
          logging.info('Evaluation metric. %s: %s.', metric_name, metric_value)
          logs['val_' + metric_name] = metric_value
    # The parent callback is responsible for writing the logs to the events file.
    super(TensorBoardWithValidation, self).on_epoch_end(epoch, logs)
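

# Minimal usage sketch (illustrative only): `imagenet_input`, `model`,
# `class_names`, `train_input_fn`, and the flag values below are hypothetical
# placeholders supplied by the surrounding training script, not by this module.
#
#   tensorboard_cb = TensorBoardWithValidation(
#       log_dir='/tmp/logs',
#       validation_imagenet_input=imagenet_input,
#       validation_steps=100,
#       validation_epochs=[10, 20, 30],
#       write_graph=False,
#       write_images=False,
#       plot_wrong=False,
#       plot_cm=True,
#       plot_pr=True,
#       classes=np.array(class_names),
#       complete_eval=True,
#       top_ks=(1, 5))
#   model.fit(train_input_fn, epochs=30, callbacks=[tensorboard_cb])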