Diffstat (limited to 'eval_utils.py')
-rw-r--r-- | eval_utils.py | 293 |
1 file changed, 293 insertions, 0 deletions
diff --git a/eval_utils.py b/eval_utils.py
new file mode 100644
index 0000000..bf00b0b
--- /dev/null
+++ b/eval_utils.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation utils for `KerasTPUModel`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import io
+import itertools
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scikitplot as skplt
+import tensorflow as tf
+from six.moves import xrange
+from sklearn.metrics import confusion_matrix
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks
+from tensorflow.python.keras import optimizers
+from tensorflow.python.platform import tf_logging as logging
+from tqdm import trange
+
+
+def save_softmax(log_dir, epoch, labels, predictions):
+  """Dumps the raw labels and softmax outputs for one epoch to an .npz file."""
+  location = os.path.join(log_dir, 'softmax' + str(epoch) + '.npz')
+  np.savez(location, labels=labels, predictions=predictions)
+
+
+def draw_graphs(self, log_dir, classes, y_true, y_probas, epoch):
+  """Renders confusion-matrix and PR-curve plots and ships them to TensorBoard."""
+  y_pred = np.argmax(y_probas, axis=1)
+  if self._plot_cm:
+    skplt.metrics.plot_confusion_matrix(y_true, y_pred, normalize=True)
+    plot_to_tensorboard(log_dir, epoch, "model_projections", "confusion_matrix")
+  if self._plot_pr:
+    skplt.metrics.plot_precision_recall(y_true, y_probas)
+    plot_to_tensorboard(log_dir, epoch, "model_projections", "pr_curve")
+
+
+def plot_to_tensorboard(log_dir, epoch, model_projection, family_name):
+  """Writes the current matplotlib figure to TensorBoard as an image summary."""
+  buf = io.BytesIO()
+  plt.rcParams.update({'font.size': 5})
+  plt.savefig(buf, dpi=250, format='png')
+  buf.seek(0)
+  image = tf.image.decode_png(buf.getvalue(), channels=3)
+  image = tf.expand_dims(image, 0)
+
+  summary_op = tf.summary.image(model_projection, image, max_outputs=1, family=family_name)
+  writer = tf.summary.FileWriter(log_dir)
+  writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
+  writer.close()
+
+
+def draw_c_matrix(log_dir, c_matrix, classes, epoch, normalize=False):
+  """Draws a confusion matrix by hand and writes it to TensorBoard."""
+  if normalize:
+    c_matrix = c_matrix.astype('float') / c_matrix.sum(axis=1)[:, np.newaxis]
+  plt.figure()
+  plt.imshow(c_matrix, cmap=plt.cm.Blues)
+  plt.xlabel('Predicted')
+  plt.ylabel('True Label')
+  tick_marks = np.arange(len(classes))
+  plt.xticks(tick_marks, classes, rotation=45)
+  plt.yticks(tick_marks, classes)
+
+  fmt = '.2f'
+  thresh = c_matrix.max() / 2.
+  for i, j in itertools.product(range(c_matrix.shape[0]), range(c_matrix.shape[1])):
+    plt.text(j, i, format(c_matrix[i, j], fmt),
+             horizontalalignment="center",
+             color="white" if c_matrix[i, j] > thresh else "black")
+
+  buf = io.BytesIO()
+  plt.savefig(buf, dpi=500, format='png')
+  buf.seek(0)
+  image = tf.image.decode_png(buf.getvalue(), channels=4)
+  image = tf.expand_dims(image, 0)
+
+  # NB: the summary family is the literal string 'family_name' here.
+  summary_op = tf.summary.image("model_projections", image, max_outputs=1, family='family_name')
+  writer = tf.summary.FileWriter(log_dir)
+  writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
+  writer.close()
+
+
+def multi_top_k_accuracy(self, log_dir, model, evaluation_generator, eval_steps,
+                         classes, epoch, ks=(1, 5)):
+  """Calculates top k accuracy for the given `k` values.
+
+  Args:
+    self: the `TensorBoardWithValidation` callback; provides the plotting
+      flags (`_plot_wrong`, `_plot_cm`, `_plot_pr`).
+    log_dir: directory for softmax dumps and TensorBoard image summaries.
+    model: `KerasTPUModel` to evaluate.
+    evaluation_generator: a Python generator to generate (features, labels) for
+      evaluation.
+    eval_steps: int, number of evaluation steps.
+    classes: sequence of class names, one per label index.
+    epoch: int, the epoch being evaluated; used to tag the outputs.
+    ks: a tuple of int, position values to calculate top k accuracy.
+
+  Returns:
+    A dictionary containing top k accuracy for the given `k` values.
+  """
+  def _count_matched(classes, predictions, labels, ks):
+    """Counts the pairs whose label appears in the top k predictions."""
+    top_k_matched = dict.fromkeys(ks, 0)
+    for prediction, label in zip(predictions, labels):
+      for k in ks:
+        # Indices of the k largest scores, in no particular order.
+        top_k_predictions = np.argpartition(prediction, -k)[-k:]
+        if label in top_k_predictions:
+          top_k_matched[k] += 1
+    return top_k_matched
+
+  total = 0
+  top_k_matched = dict.fromkeys(ks, 0)
+  c_matrix = np.zeros((len(classes), len(classes)))
+  all_labels = np.zeros((0, 1))
+  all_predictions = np.zeros((0, len(classes)))
+  logging.info('There are %d validation steps', eval_steps)
+  t = trange(eval_steps)
+  for step in t:
+    try:
+      (features, labels) = next(evaluation_generator)
+    except Exception as e:  # pylint: disable=broad-except
+      logging.debug(e)
+      break
+    predictions = model.predict_on_batch(features)  # may be quicker than predict()
+    # predictions = model.predict(features, batch_size=8)
+    sorted_pred_args = np.flip(predictions.argsort(axis=1), axis=1)
+    flat_predictions = sorted_pred_args[:, 0]
+
+    # TODO: clean this function, it is a mess.
+
+    # Send some confidently misclassified images to TensorBoard.
+    if self._plot_wrong:  # and not (step + 4) % 8:
+      # Squeeze labels into the same dimension and type as predictions.
+      sq_labels = np.squeeze(labels.astype(int))
+      # With a single image per batch, squeeze removes one dimension too many.
+      if sq_labels.shape == ():
+        sq_labels = np.expand_dims(sq_labels, axis=0)
+      failed_indexes = np.where(np.not_equal(flat_predictions, sq_labels))[0]
+      limiter = 0
+      for idx in failed_indexes:
+        if limiter > 90:  # cap the number of images plotted per batch
+          break
+        limiter += 1
+        predicted_class_name = classes[flat_predictions[idx]]
+        true_class_name = classes[sq_labels[idx]]
+        proba_range = range(3)  # show softmax for the top 3 classes
+        top_cl = [classes[i] for i in sorted_pred_args[idx][proba_range]]
+        probas = predictions[idx][sorted_pred_args[idx][proba_range]]
+        if probas[0] > 0.9:
+          top_3 = '\n'.join(cl + ": " + proba
+                            for cl, proba in zip(top_cl, probas.astype(str)))
+          print("Predicted", flat_predictions[idx],
+                "True:", sq_labels[idx],
+                "Proba:", probas.astype(str))
+          plt.clf()
+          plt.imshow(features[idx].astype(int))
+          plt.text(0, 0, top_3, size=9, va="bottom",
+                   bbox=dict(boxstyle="square", ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8)))
+          plot_to_tensorboard(log_dir, epoch, "mislabeled_images",
+                              "P_" + predicted_class_name + "_Tr_" + true_class_name)
+
+    c_matrix += confusion_matrix(labels, flat_predictions, labels=range(len(classes)))
+    batch_top_k_matched = _count_matched(classes, predictions, labels, ks)
+    all_labels = np.vstack((all_labels, labels))
+    all_predictions = np.vstack((all_predictions, predictions))
+    for k, matched in batch_top_k_matched.items():
+      top_k_matched[k] += matched
+    total += len(labels)
+
+    t.set_description("Top 1: %f" % np.float_(top_k_matched[1] / float(total)))
+
+  logging.info("Confusion matrix:")
+  print(c_matrix)
+
+  try:
+    # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch, normalize=True)
+    # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch, normalize=False)
+    draw_graphs(self, log_dir, classes, all_labels, all_predictions, epoch)
+  except Exception as e:  # pylint: disable=broad-except
+    logging.warning('Failed to draw evaluation graphs: %s', e)
+
+  save_softmax(log_dir, epoch, all_labels, all_predictions)
+  metrics = dict([('top_{0}_accuracy'.format(k), np.float_(matched / float(total)))
+                  for k, matched in top_k_matched.items()])
+  print(metrics)
+  return metrics
+
+
+class TensorBoardWithValidation(callbacks.TensorBoard):
+  """Extends the TensorBoard callback with validation.
+
+  Validation is executed at the end of specified epochs, and the validation
+  metrics are exported to TensorBoard for visualization.
+
+  Args:
+    log_dir: the path of the directory where to save the log
+      files to be parsed by TensorBoard.
+    validation_imagenet_input: ImageNetInput for validation.
+    validation_steps: total number of steps to validate.
+    validation_epochs: a list of integers, epochs to run validation.
+    write_graph: boolean, forwarded to the parent TensorBoard callback.
+    write_images: boolean, forwarded to the parent TensorBoard callback.
+    plot_wrong: boolean, if true, write misclassified images to TensorBoard.
+    plot_cm: boolean, if true, write a normalized confusion matrix plot.
+    plot_pr: boolean, if true, write precision/recall curves.
+    classes: sequence of class names, one per label index.
+    complete_eval: boolean, if true, evaluate top k accuracies using
+      multi_top_k_accuracy(). Otherwise, use model.evaluate().
+      N.B. enabling this would significantly slow down the eval time due to
+      using a Python generator for evaluation input.
+    top_ks: a tuple of int, position values to calculate top k accuracy. It's
+      only used when complete_eval is true.
+  """
+
+  def __init__(self,
+               log_dir,
+               validation_imagenet_input,
+               validation_steps,
+               validation_epochs,
+               write_graph,
+               write_images,
+               plot_wrong,
+               plot_cm,
+               plot_pr,
+               classes,
+               complete_eval,
+               top_ks=(1, 5)):
+    super(TensorBoardWithValidation, self).__init__(log_dir,
+                                                    write_graph=write_graph,
+                                                    write_images=write_images)
+    self._validation_imagenet_input = validation_imagenet_input
+    self._validation_steps = validation_steps
+    self._validation_epochs = validation_epochs
+    self._plot_wrong = plot_wrong
+    self._plot_cm = plot_cm
+    self._plot_pr = plot_pr
+    self._complete_eval = complete_eval
+    self._top_ks = top_ks
+    self._targets = classes
+    self._log_dir = log_dir
+
+  def on_epoch_end(self, epoch, logs=None):
+    if epoch in self._validation_epochs:
+      logging.info('\nValidate in epoch %s', epoch)
+      if self._complete_eval:
+        logging.info("Running complete eval")
+        score = multi_top_k_accuracy(
+            self,
+            self._log_dir,
+            self.model,
+            self._validation_imagenet_input.evaluation_generator(
+                K.get_session()),
+            self._validation_steps,
+            self._targets,
+            epoch,
+            ks=self._top_ks)
+        for metric_name, metric_value in score.items():
+          logs['val_' + metric_name] = metric_value
+      else:
+        # evaluate() is executed as a callback during training. In this case,
+        # _numpy_to_infeed_manager_list is not empty, so save it for
+        # recovery at the end of the evaluate call.
+        # TODO(jingli): remove this monkey patch hack once the fix is included
+        # in a future TF release.
+        original_numpy_to_infeed_manager_list = []
+        if self.model._numpy_to_infeed_manager_list:
+          original_numpy_to_infeed_manager_list = (
+              self.model._numpy_to_infeed_manager_list)
+          self.model._numpy_to_infeed_manager_list = []
+        # Set _eval_function to None to force recompilation, so that the newly
+        # created dataset in self._validation_imagenet_input.input_fn is used
+        # in evaluation.
+        # pylint: disable=bare-except
+        # pylint: disable=protected-access
+        try:
+          self.model._eval_function = None
+        except:
+          pass
+
+        try:
+          # In TF 1.12, _eval_function does not exist; only test_function
+          # existed.
+          self.model.test_function = None
+        except:
+          pass
+
+        scores = self.model.evaluate(self._validation_imagenet_input.input_fn,
+                                     steps=self._validation_steps)
+        self.model._numpy_to_infeed_manager_list = (
+            original_numpy_to_infeed_manager_list)
+        for metric_name, metric_value in zip(self.model.metrics_names, scores):
+          logging.info('Evaluation metric. %s: %s.', metric_name, metric_value)
+          logs['val_' + metric_name] = metric_value
+    # The parent callback is responsible for writing the logs as an events file.
+    super(TensorBoardWithValidation, self).on_epoch_end(epoch, logs)
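
For readers skimming the diff: the top-k bookkeeping in _count_matched hinges on a single NumPy idiom, np.argpartition(prediction, -k)[-k:], which returns the indices of the k largest scores without a full sort, so membership of the true label in that index set is exactly the "label in top-k" test. A minimal, self-contained sketch with toy scores (the values below are made up for illustration, not taken from the module):

import numpy as np

prediction = np.array([0.05, 0.60, 0.10, 0.20, 0.05])
label = 3
for k in (1, 3):
  # Indices of the k largest scores, in no particular order.
  top_k = np.argpartition(prediction, -k)[-k:]
  print(k, sorted(top_k), label in top_k)
# k=1 -> [1]        -> False (label 3 is not the argmax)
# k=3 -> [1, 2, 3]  -> True  (label 3 is among the 3 largest scores)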
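
And a hedged sketch of how the callback might be wired into a training run. Everything named here apart from TensorBoardWithValidation itself (model, imagenet_input, training_input_fn, class_names) is a hypothetical stand-in; any input object exposing evaluation_generator(session) and input_fn, as used in on_epoch_end above, would do:

from eval_utils import TensorBoardWithValidation

tensorboard = TensorBoardWithValidation(
    log_dir='/tmp/logs',
    validation_imagenet_input=imagenet_input,  # hypothetical; exposes .evaluation_generator(sess) and .input_fn
    validation_steps=100,
    validation_epochs=[5, 10, 20],
    write_graph=False,
    write_images=False,
    plot_wrong=False,     # dump confidently misclassified images to TensorBoard
    plot_cm=True,         # confusion-matrix plot via scikit-plot
    plot_pr=True,         # precision/recall curves via scikit-plot
    classes=class_names,  # hypothetical list of class-name strings
    complete_eval=True,   # use multi_top_k_accuracy() instead of model.evaluate()
    top_ks=(1, 5))

model.fit(training_input_fn,  # hypothetical KerasTPUModel and input fn
          epochs=20,
          steps_per_epoch=1000,
          callbacks=[tensorboard])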