diff --git a/eval_utils.py b/eval_utils.py
new file mode 100644
index 0000000..bf00b0b
--- /dev/null
+++ b/eval_utils.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation utils for `KerasTPUmodel`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import io
+import itertools
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scikitplot as skplt
+from sklearn.metrics import confusion_matrix
+from tqdm import trange
+
+import tensorflow as tf
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks
+from tensorflow.python.platform import tf_logging as logging
+
+def save_softmax(log_dir, epoch, labels, predictions):
+  """Dumps labels and softmax predictions for an epoch to an .npz file."""
+  location = os.path.join(log_dir, 'softmax' + str(epoch) + '.npz')
+  np.savez(location, labels=labels, predictions=predictions)
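+  # The saved arrays can be reloaded later for offline analysis, e.g.
+  # (the epoch number below is illustrative):
+  #   data = np.load(os.path.join(log_dir, 'softmax10.npz'))
+  #   labels, predictions = data['labels'], data['predictions']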
+
+def draw_graphs(self, log_dir, classes, y_true, y_probas, epoch):
+  """Plots a confusion matrix and a PR curve and exports them to TensorBoard.
+
+  `self` is the TensorBoardWithValidation callback instance; only its
+  `_plot_cm` and `_plot_pr` flags are read here.
+  """
+  y_pred = np.argmax(y_probas, axis=1)
+  if self._plot_cm:
+    skplt.metrics.plot_confusion_matrix(y_true, y_pred, normalize=True)
+    plot_to_tensorboard(log_dir, epoch, "model_projections", "confusion_matrix")
+  if self._plot_pr:
+    skplt.metrics.plot_precision_recall(y_true, y_probas)
+    plot_to_tensorboard(log_dir, epoch, "model_projections", "pr_curve")
+
+def plot_to_tensorboard(log_dir, epoch, model_projection, family_name):
+  """Exports the current matplotlib figure as a TensorBoard image summary."""
+  buf = io.BytesIO()
+  plt.rcParams.update({'font.size': 5})
+  plt.savefig(buf, dpi=250, format='png')
+  buf.seek(0)
+  image = tf.image.decode_png(buf.getvalue(), channels=3)
+  image = tf.expand_dims(image, 0)
+
+  # tf.summary.image yields a string tensor holding a serialized Summary proto;
+  # evaluate it in the Keras session and append it to the event file.
+  summary_op = tf.summary.image(model_projection, image, max_outputs=1,
+                                family=family_name)
+  writer = tf.summary.FileWriter(log_dir)
+  writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
+  writer.close()
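+# plot_to_tensorboard exports whatever is on the current matplotlib figure, so
+# a plot must be drawn first. A minimal call sketch (names are illustrative):
+#   plt.figure()
+#   plt.plot(loss_values)
+#   plot_to_tensorboard(log_dir, epoch, "model_projections", "loss_curve")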
+
+def draw_c_matrix(log_dir, c_matrix, classes, epoch, normalize=False):
+  """Renders a confusion matrix with matplotlib and exports it to TensorBoard."""
+  if normalize:
+    c_matrix = c_matrix.astype('float') / c_matrix.sum(axis=1)[:, np.newaxis]
+  plt.figure()
+  plt.imshow(c_matrix, cmap=plt.cm.Blues)
+  plt.xlabel('Predicted')
+  plt.ylabel('True Label')
+  tick_marks = np.arange(len(classes))
+  plt.xticks(tick_marks, classes, rotation=45)
+  plt.yticks(tick_marks, classes)
+
+  # Annotate each cell; use light text on dark cells for readability.
+  fmt = '.2f' if normalize else '.0f'
+  thresh = c_matrix.max() / 2.
+  for i, j in itertools.product(range(c_matrix.shape[0]), range(c_matrix.shape[1])):
+    plt.text(j, i, format(c_matrix[i, j], fmt),
+             horizontalalignment="center",
+             color="white" if c_matrix[i, j] > thresh else "black")
+
+  buf = io.BytesIO()
+  plt.savefig(buf, dpi=500, format='png')
+  buf.seek(0)
+  image = tf.image.decode_png(buf.getvalue(), channels=4)
+  image = tf.expand_dims(image, 0)
+
+  # Note: the summary family is hard-coded here rather than passed in.
+  summary_op = tf.summary.image("model_projections", image, max_outputs=1,
+                                family='family_name')
+  writer = tf.summary.FileWriter(log_dir)
+  writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
+  writer.close()
+
+def multi_top_k_accuracy(self, log_dir, model, evaluation_generator, eval_steps,
+                         classes, epoch, ks=(1, 5)):
+  """Calculates top k accuracy for the given `k` values.
+
+  Args:
+    self: the TensorBoardWithValidation callback instance (used for its
+      plotting flags).
+    log_dir: directory to which softmax dumps and image summaries are written.
+    model: `KerasTPUModel` to evaluate.
+    evaluation_generator: a Python generator yielding (features, labels) for
+      evaluation.
+    eval_steps: int, number of evaluation steps.
+    classes: array of class names, indexed by label.
+    epoch: int, the epoch being evaluated.
+    ks: a tuple of int, position values to calculate top k accuracy.
+
+  Returns:
+    A dictionary containing top k accuracy for the given `k` values.
+  """
+ def _count_matched(classes, predictions, labels, ks):
+ """Count number of pairs with label in any of top k predictions."""
+ top_k_matched = dict.fromkeys(ks, 0)
+ for prediction, label in zip(predictions, labels):
+ for k in ks:
+ top_k_predictions = np.argpartition(prediction, -k)[-k:]
+ if label in top_k_predictions:
+ top_k_matched[k] += 1
+
+ return top_k_matched
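+    # Illustrative example (not part of the original code): for a 3-class
+    # prediction [0.1, 0.7, 0.2] with true label 2, np.argpartition yields
+    # top-2 indices {1, 2}, so the pair counts toward k=2 but not toward
+    # k=1, whose single top index is 1.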
+
+  total = 0
+  top_k_matched = dict.fromkeys(ks, 0)
+  c_matrix = np.zeros((len(classes), len(classes)))
+  all_labels = np.zeros((0, 1))
+  all_predictions = np.zeros((0, len(classes)))
+  logging.info('There are %d validation steps', eval_steps)
+  t = trange(eval_steps)
+ for step in t:
+    try:
+      (features, labels) = next(evaluation_generator)
+    except Exception as e:  # The generator ran out of batches early.
+      logging.debug(e)
+      break
+    # predict_on_batch avoids the batching overhead of
+    # model.predict(features, batch_size=8) and may be quicker.
+    predictions = model.predict_on_batch(features)
+    # Class indices sorted by descending probability; column 0 is the argmax.
+    sorted_pred_args = np.flip(predictions.argsort(axis=1), axis=1)
+    flat_predictions = sorted_pred_args[:, 0]
+
+    # TODO: Clean up this function.
+
+    # Plot a sample of incorrectly predicted images.
+    if self._plot_wrong:
+      # Squeeze labels into the same dimension and dtype as the predictions.
+      sq_labels = np.squeeze(labels.astype(int))
+      # With a single image per batch, squeeze removes one dimension too many.
+      if sq_labels.shape == ():
+        sq_labels = np.expand_dims(sq_labels, axis=0)
+      failed_indexes = np.where(np.not_equal(flat_predictions, sq_labels))[0]
+      limiter = 0
+      for idx in failed_indexes:
+        # Cap the number of mislabelled images exported per batch.
+        if limiter > 90:
+          break
+        limiter += 1
+        predicted_class_name = classes[flat_predictions[idx]]
+        true_class_name = classes[sq_labels[idx]]
+        proba_range = range(3)  # Show softmax values for the top 3 classes.
+        top_cl = classes[sorted_pred_args[idx][proba_range]]
+        probas = predictions[idx][sorted_pred_args[idx][proba_range]]
+        # Only export images that were misclassified with high confidence.
+        if probas[0] > 0.9:
+          top_3 = '\n'.join(cl + ": " + proba
+                            for cl, proba in zip(top_cl, probas.astype(str)))
+          print("Predicted:", flat_predictions[idx],
+                "True:", sq_labels[idx],
+                "Proba:", probas.astype(str))
+          plt.clf()
+          plt.imshow(features[idx].astype(int))
+          plt.text(0, 0, top_3, size=9, va="bottom",
+                   bbox=dict(boxstyle="square", ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8)))
+          plot_to_tensorboard(log_dir, epoch, "mislabeled_images",
+                              "P_" + predicted_class_name + "_Tr_" + true_class_name)
+
+ c_matrix += confusion_matrix(labels, flat_predictions, labels=range(len(classes)))
+ batch_top_k_matched = _count_matched(classes, predictions, labels, ks)
+ all_labels = np.vstack((all_labels, labels))
+ all_predictions = np.vstack((all_predictions, predictions))
+ for k, matched in batch_top_k_matched.items():
+ top_k_matched[k] += matched
+ total += len(labels)
+
+ t.set_description("Top 1: %f" % np.float_(top_k_matched[1]/float(total)))
+
+ logging.info("Confusion matrix:")
+ print(c_matrix)
+
+  try:
+    # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch, normalize=True)
+    # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch, normalize=False)
+    draw_graphs(self, log_dir, classes, all_labels, all_predictions, epoch)
+  except Exception as e:  # pylint: disable=broad-except
+    logging.warning('Failed to export evaluation plots: %s', e)
+
+ save_softmax(log_dir, epoch, all_labels, all_predictions)
+  metrics = {'top_{0}_accuracy'.format(k): np.float_(matched / float(total))
+             for k, matched in top_k_matched.items()}
+ print(metrics)
+ return metrics
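+# A typical return value from multi_top_k_accuracy looks like the following
+# (numbers are illustrative only):
+#   {'top_1_accuracy': 0.72, 'top_5_accuracy': 0.91}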
+
+
+class TensorBoardWithValidation(callbacks.TensorBoard):
+ """Extend TensorBoard Callback with validation .
+
+ Validation is executed at the end of specified epochs, and the validation
+ metrics are exported to tensorboard for visualization.
+
+ Args:
+ log_dir: the path of the directory where to save the log
+ files to be parsed by TensorBoard.
+ validation_imagenet_input: ImageNetInput for validation.
+ validation_steps: total number of steps to validate.
+ validation_epochs: a list of integers, epochs to run validation.
+ eval_top_k_accuracy: boolean, if true, evaluate top k accuracies using
+ multi_top_k_accuracy(). Otherwise, use model.evaluate().
+ N.B. enabling this would significantly slow down the eval time due to
+ using python generator for evaluation input.
+ top_ks: a tuple of int, position values to calculate top k accurary. It's
+ only used when eval_top_k_accuracy is true.
+ """
+
+ def __init__(self,
+ log_dir,
+ validation_imagenet_input,
+ validation_steps,
+ validation_epochs,
+ write_graph,
+ write_images,
+ plot_wrong,
+ plot_cm,
+ plot_pr,
+ classes,
+ complete_eval,
+ top_ks=(1, 5)):
+ super(TensorBoardWithValidation, self).__init__(log_dir)
+ self._validation_imagenet_input = validation_imagenet_input
+ self._validation_steps = validation_steps
+ self._validation_epochs = validation_epochs
+ self._write_graph = write_graph
+ self._write_images = write_images
+ self._plot_wrong = plot_wrong
+ self._plot_cm = plot_cm
+ self._plot_pr = plot_pr
+ self._complete_eval = complete_eval
+ self._top_ks = top_ks
+ self._targets = classes
+ self._log_dir = log_dir
+
+ def on_epoch_end(self, epoch, logs=None):
+ if epoch in self._validation_epochs:
+ logging.info('\nValidate in epoch %s', epoch)
+ if self._complete_eval:
+ logging.info("Running complete eval")
+ score = multi_top_k_accuracy(
+ self,
+ self._log_dir,
+ self.model,
+ self._validation_imagenet_input.evaluation_generator(
+ K.get_session()),
+ self._validation_steps,
+ self._targets,
+ epoch,
+ ks=self._top_ks)
+ for metric_name, metric_value in score.items():
+ logs['val_' + metric_name] = metric_value
+ else:
+        # evaluate() is executed as a callback during training. In this case,
+        # _numpy_to_infeed_manager_list is not empty, so save it for recovery
+        # at the end of the evaluate call.
+        # TODO(jingli): remove this monkey patch hack once the fix is included
+        # in a future TF release.
+ original_numpy_to_infeed_manager_list = []
+ if self.model._numpy_to_infeed_manager_list:
+ original_numpy_to_infeed_manager_list = (
+ self.model._numpy_to_infeed_manager_list)
+ self.model._numpy_to_infeed_manager_list = []
+        # Set _eval_function to None to force recompilation so that evaluation
+        # uses the newly created dataset from
+        # self._validation_imagenet_input.input_fn.
+        # pylint: disable=bare-except
+        # pylint: disable=protected-access
+        try:
+          self.model._eval_function = None
+        except:
+          pass
+
+        try:
+          # In TF 1.12 _eval_function does not exist; only test_function did.
+          self.model.test_function = None
+        except:
+          pass
+
+ scores = self.model.evaluate(self._validation_imagenet_input.input_fn,
+ steps=self._validation_steps)
+ self.model._numpy_to_infeed_manager_list = (
+ original_numpy_to_infeed_manager_list)
+ for metric_name, metric_value in zip(self.model.metrics_names, scores):
+ logging.info('Evaluation metric. %s: %s.', metric_name, metric_value)
+ logs['val_' + metric_name] = metric_value
+ # The parent callback is responsible to write the logs as events file.
+ super(TensorBoardWithValidation, self).on_epoch_end(epoch, logs)
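+
+# A minimal usage sketch (assumptions: `model` is a compiled KerasTPUModel,
+# `imagenet_eval` is an ImageNetInput for the validation split, and
+# `class_names` maps label index to class name; these names are illustrative
+# and are not defined in this module):
+#
+#   tb_cb = TensorBoardWithValidation(
+#       log_dir='/tmp/logs',
+#       validation_imagenet_input=imagenet_eval,
+#       validation_steps=195,
+#       validation_epochs=[5, 10],
+#       write_graph=False,
+#       write_images=False,
+#       plot_wrong=False,
+#       plot_cm=True,
+#       plot_pr=True,
+#       classes=class_names,
+#       complete_eval=True)
+#   model.fit(..., callbacks=[tb_cb])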