# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation utils for `KerasTPUModel`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import itertools
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import scikitplot as skplt
import tensorflow as tf
from six.moves import xrange
from sklearn.metrics import confusion_matrix
from tqdm import trange

from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import optimizers
from tensorflow.python.platform import tf_logging as logging


def save_softmax(log_dir, epoch, labels, predictions):
  """Saves labels and softmax predictions for this epoch as an .npz file."""
  location = os.path.join(log_dir, 'softmax' + str(epoch) + '.npz')
  np.savez(location, labels=labels, predictions=predictions)


def draw_graphs(self, log_dir, classes, y_true, y_probas, epoch):
  """Plots confusion matrix and precision-recall curves to TensorBoard.

  `self` is the TensorBoardWithValidation callback instance; it is only used
  to read the plotting flags.
  """
  y_pred = np.argmax(y_probas, axis=1)
  if self._plot_cm:
    skplt.metrics.plot_confusion_matrix(y_true, y_pred, normalize=True)
    plot_to_tensorboard(log_dir, epoch, "model_projections",
                        "confusion_matrix")
  if self._plot_pr:
    skplt.metrics.plot_precision_recall(y_true, y_probas)
    plot_to_tensorboard(log_dir, epoch, "model_projections", "pr_curve")


def plot_to_tensorboard(log_dir, epoch, model_projection, family_name):
  """Writes the current matplotlib figure to TensorBoard as an image summary."""
  buf = io.BytesIO()
  plt.rcParams.update({'font.size': 5})
  plt.savefig(buf, dpi=250, format='png')
  buf.seek(0)
  image = tf.image.decode_png(buf.getvalue(), channels=3)
  image = tf.expand_dims(image, 0)
  summary_op = tf.summary.image(model_projection, image, max_outputs=1,
                                family=family_name)
  writer = tf.summary.FileWriter(log_dir)
  writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
  writer.close()


def draw_c_matrix(log_dir, c_matrix, classes, epoch, normalize=False):
  """Renders a confusion matrix and writes it to TensorBoard as an image."""
  if normalize:
    c_matrix = c_matrix.astype('float') / c_matrix.sum(axis=1)[:, np.newaxis]
  plt.figure()
  plt.imshow(c_matrix, cmap=plt.cm.Blues)
  plt.xlabel('Predicted')
  plt.ylabel('True Label')
  tick_marks = np.arange(len(classes))
  plt.xticks(tick_marks, classes, rotation=45)
  plt.yticks(tick_marks, classes)
  fmt = '.2f'
  thresh = c_matrix.max() / 2.
  for i, j in itertools.product(range(c_matrix.shape[0]),
                                range(c_matrix.shape[1])):
    plt.text(j, i, format(c_matrix[i, j], fmt),
             horizontalalignment="center",
             color="white" if c_matrix[i, j] > thresh else "black")
  buf = io.BytesIO()
  plt.savefig(buf, dpi=500, format='png')
  buf.seek(0)
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  image = tf.expand_dims(image, 0)
  summary_op = tf.summary.image("model_projections", image, max_outputs=1,
                                family='family_name')
  writer = tf.summary.FileWriter(log_dir)
  writer.add_summary(summary_op.eval(session=K.get_session()), epoch)
  writer.close()
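# Illustrative use of draw_c_matrix (a sketch only; the '/tmp/logs' path and
# the toy matrix below are made-up values, and an active Keras/TF session is
# assumed because the summary op is evaluated via K.get_session()):
#
#   c_matrix = np.array([[5., 1.],
#                        [2., 4.]])
#   draw_c_matrix('/tmp/logs', c_matrix, classes=['cat', 'dog'], epoch=0,
#                 normalize=True)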
def multi_top_k_accuracy(self, log_dir, model, evaluation_generator,
                         eval_steps, classes, epoch, ks=(1, 5)):
  """Calculates top k accuracy for the given `k` values.

  Args:
    self: the TensorBoardWithValidation callback instance, used to read the
      plotting flags.
    log_dir: directory where TensorBoard summaries and softmax dumps are
      written.
    model: `KerasTPUModel` to evaluate.
    evaluation_generator: a Python generator to generate (features, labels)
      for evaluation.
    eval_steps: int, number of evaluation steps.
    classes: list of class names, used for the confusion matrix and plots.
    epoch: int, the epoch being evaluated, used to tag the summaries.
    ks: a tuple of int, position values to calculate top k accuracy.

  Returns:
    A dictionary containing top k accuracy for the given `k` values.
  """

  def _count_matched(classes, predictions, labels, ks):
    """Counts pairs whose label appears in any of the top k predictions."""
    top_k_matched = dict.fromkeys(ks, 0)
    for prediction, label in zip(predictions, labels):
      for k in ks:
        top_k_predictions = np.argpartition(prediction, -k)[-k:]
        if label in top_k_predictions:
          top_k_matched[k] += 1
    return top_k_matched

  total = 0
  top_k_matched = dict.fromkeys(ks, 0)
  c_matrix = np.zeros((len(classes), len(classes)))
  all_labels = np.zeros((0, 1))
  all_predictions = np.zeros((0, len(classes)))
  logging.info('There are %d validation steps', eval_steps)
  t = trange(eval_steps)
  for step in t:
    try:
      (features, labels) = next(evaluation_generator)
    except Exception as e:  # pylint: disable=broad-except
      logging.debug(e)
      break
    predictions = model.predict_on_batch(features)
    # May be quicker:
    # predictions = model.predict(features, batch_size=8)
    sorted_pred_args = np.flip(predictions.argsort(axis=1), axis=1)
    flat_predictions = sorted_pred_args[:, 0]
    # TODO: clean up this function, it is a mess.
    # Print some falsely predicted images.
    if self._plot_wrong:  # and not (step + 4) % 8:
      # Squeeze labels into the same dimension and type as predictions.
      sq_labels = np.squeeze(labels.astype(int))
      # With a single image per batch, squeeze removes one dimension too many.
      if sq_labels.shape == ():
        sq_labels = np.expand_dims(sq_labels, axis=0)
      failed_indexes = np.where(np.not_equal(flat_predictions, sq_labels))[0]
      limiter = 0
      for idx in failed_indexes:
        if limiter > 90:
          break
        limiter += 1
        predicted_class_name = classes[flat_predictions[idx]]
        true_class_name = classes[sq_labels[idx]]
        proba_range = range(3)  # Show softmax for the top 3 classes.
        top_cl = classes[sorted_pred_args[idx][proba_range]]
        probas = predictions[idx][sorted_pred_args[idx][proba_range]]
        if probas[0] > 0.9:
          top_3 = '\n'.join(cl + ": " + proba
                            for cl, proba in zip(top_cl, probas.astype(str)))
          print("Predicted", flat_predictions[idx], "True:", sq_labels[idx],
                "Proba:", probas.astype(str))
          plt.clf()
          plt.imshow(features[idx].astype(int))
          plt.text(0, 0, top_3, size=9, va="bottom",
                   bbox=dict(boxstyle="square",
                             ec=(1., 0.5, 0.5),
                             fc=(1., 0.8, 0.8)))
          plot_to_tensorboard(log_dir, epoch, "mislabeled_images",
                              "P_" + predicted_class_name +
                              "_Tr_" + true_class_name)
    c_matrix += confusion_matrix(labels, flat_predictions,
                                 labels=range(len(classes)))
    batch_top_k_matched = _count_matched(classes, predictions, labels, ks)
    all_labels = np.vstack((all_labels, labels))
    all_predictions = np.vstack((all_predictions, predictions))
    for k, matched in batch_top_k_matched.items():
      top_k_matched[k] += matched
    total += len(labels)
    t.set_description("Top 1: %f" % np.float_(top_k_matched[1] / float(total)))

  logging.info("Confusion matrix:")
  print(c_matrix)
  # pylint: disable=bare-except
  try:
    # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch,
    #               normalize=True)
    # draw_c_matrix(self._log_dir, c_matrix, self._targets, epoch,
    #               normalize=False)
    draw_graphs(self, log_dir, classes, all_labels, all_predictions, epoch)
  except:
    pass
  save_softmax(log_dir, epoch, all_labels, all_predictions)
  metrics = dict([('top_{0}_accuracy'.format(k),
                   np.float_(matched / float(total)))
                  for k, matched in top_k_matched.items()])
  print(metrics)
  return metrics
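# Note on the top-k counting above (an illustrative aside, not part of the
# original code): np.argpartition returns the indices of the k largest scores
# without doing a full sort, e.g.
#
#   probs = np.array([0.1, 0.5, 0.05, 0.3, 0.05])
#   np.argpartition(probs, -2)[-2:]  # indices 1 and 3, the two largest scores
#
# so a label counts as matched whenever it appears among those indices.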
class TensorBoardWithValidation(callbacks.TensorBoard):
  """Extends the TensorBoard callback with validation.

  Validation is executed at the end of specified epochs, and the validation
  metrics are exported to TensorBoard for visualization.

  Args:
    log_dir: the path of the directory where to save the log files to be
      parsed by TensorBoard.
    validation_imagenet_input: ImageNetInput for validation.
    validation_steps: total number of steps to validate.
    validation_epochs: a list of integers, epochs to run validation.
    write_graph: boolean, forwarded TensorBoard flag for writing the graph.
    write_images: boolean, forwarded TensorBoard flag for writing model images.
    plot_wrong: boolean, if true, export misclassified images to TensorBoard.
    plot_cm: boolean, if true, export a confusion matrix plot.
    plot_pr: boolean, if true, export a precision-recall plot.
    classes: list of class names used for plots and the confusion matrix.
    complete_eval: boolean, if true, evaluate top k accuracies using
      multi_top_k_accuracy(). Otherwise, use model.evaluate(). N.B. enabling
      this would significantly slow down the eval time due to using a python
      generator for evaluation input.
    top_ks: a tuple of int, position values to calculate top k accuracy.
      It's only used when complete_eval is true.
  """

  def __init__(self, log_dir, validation_imagenet_input, validation_steps,
               validation_epochs, write_graph, write_images, plot_wrong,
               plot_cm, plot_pr, classes, complete_eval, top_ks=(1, 5)):
    super(TensorBoardWithValidation, self).__init__(log_dir)
    self._validation_imagenet_input = validation_imagenet_input
    self._validation_steps = validation_steps
    self._validation_epochs = validation_epochs
    self._write_graph = write_graph
    self._write_images = write_images
    self._plot_wrong = plot_wrong
    self._plot_cm = plot_cm
    self._plot_pr = plot_pr
    self._complete_eval = complete_eval
    self._top_ks = top_ks
    self._targets = classes
    self._log_dir = log_dir

  def on_epoch_end(self, epoch, logs=None):
    if epoch in self._validation_epochs:
      logging.info('\nValidate in epoch %s', epoch)
      if self._complete_eval:
        logging.info("Running complete eval")
        score = multi_top_k_accuracy(
            self,
            self._log_dir,
            self.model,
            self._validation_imagenet_input.evaluation_generator(
                K.get_session()),
            self._validation_steps,
            self._targets,
            epoch,
            ks=self._top_ks)
        for metric_name, metric_value in score.items():
          logs['val_' + metric_name] = metric_value
      else:
        # evaluate() is executed as a callback during training. In this case,
        # _numpy_to_infeed_manager_list is not empty, so save it for
        # recovery at the end of the evaluate call.
        # TODO(jingli): remove this monkey patch hack once the fix is included
        # in a future TF release.
        original_numpy_to_infeed_manager_list = []
        if self.model._numpy_to_infeed_manager_list:
          original_numpy_to_infeed_manager_list = (
              self.model._numpy_to_infeed_manager_list)
          self.model._numpy_to_infeed_manager_list = []
        # Set _eval_function to None to force recompilation to use the newly
        # created dataset in self._validation_imagenet_input.input_fn in
        # evaluation.
        # pylint: disable=bare-except
        # pylint: disable=protected-access
        try:
          self.model._eval_function = None
        except:
          pass
        try:
          # In TF 1.12, _eval_function does not exist, only test_function
          # existed.
          self.model.test_function = None
        except:
          pass
        scores = self.model.evaluate(self._validation_imagenet_input.input_fn,
                                     steps=self._validation_steps)
        self.model._numpy_to_infeed_manager_list = (
            original_numpy_to_infeed_manager_list)
        for metric_name, metric_value in zip(self.model.metrics_names, scores):
          logging.info('Evaluation metric. %s: %s.', metric_name, metric_value)
          logs['val_' + metric_name] = metric_value
    # The parent callback is responsible for writing the logs as an events
    # file.
    super(TensorBoardWithValidation, self).on_epoch_end(epoch, logs)
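# Example wiring (an illustrative sketch only; `model`, `imagenet_train`,
# `imagenet_eval` and `class_names` are hypothetical objects from the training
# script, not defined in this module):
#
#   tensorboard_cb = TensorBoardWithValidation(
#       log_dir='/tmp/logs',
#       validation_imagenet_input=imagenet_eval,
#       validation_steps=100,
#       validation_epochs=[5, 10, 15],
#       write_graph=False,
#       write_images=False,
#       plot_wrong=False,
#       plot_cm=True,
#       plot_pr=True,
#       classes=class_names,
#       complete_eval=True,
#       top_ks=(1, 5))
#   model.fit(imagenet_train.input_fn,
#             steps_per_epoch=1000,
#             epochs=15,
#             callbacks=[tensorboard_cb])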