#!/usr/bin/python
"""Plot per-class precision-recall curves spread over a grid of subplots.

Reads a softmax ``.npz`` file (keys ``labels`` and ``predictions``) and a
dinfo ``.npz`` file (key ``classes``) named on the command line, then draws
one precision-recall curve per selected class, splitting the classes over
``--chunks`` subplots so each legend stays readable.
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.preprocessing import label_binarize
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score

flags = tf.app.flags
flags.DEFINE_string('softmax', None,
                    'The softmax.npz file contained labels and probas')
flags.DEFINE_string('dinfo', None, 'The dinfo.npz file')
flags.DEFINE_integer('chunks', 4, 'The number of plots to produce')
FLAGS = flags.FLAGS

# Copy the arrays out inside ``with`` so the .npz archives are closed promptly.
with np.load(FLAGS.softmax) as softmax:
    y_true = softmax['labels']
    y_proba = softmax['predictions']
with np.load(FLAGS.dinfo) as dinfo:
    class_names = dinfo['classes']


def plot_precision_recall(y_true, y_probas, plot_micro=True,
                          classes_to_plot=None, ax=None, figsize=None,
                          cmap='nipy_spectral', text_fontsize="medium",
                          chunks=None, names=None):
    """Draw precision-recall curves for the selected classes on a subplot grid.

    The selected classes are split into ``chunks`` consecutive groups of at
    most ``ceil(n_selected / chunks)`` classes. Group ``k`` is drawn on
    subplot ``(k // 2, k % 2)`` of a two-column grid, and colours restart
    from the beginning of ``cmap`` in every group.

    Args:
        y_true: Ground-truth labels, one per sample.
        y_probas: Predicted scores. Column ``i`` is read as the score of the
            ``i``-th entry of ``np.unique(y_true)``. Presumably the shape is
            ``(n_samples, n_classes)`` -- confirm against the producer.
        plot_micro: Also draw the micro-averaged curve (over all classes)
            on the last subplot.
        classes_to_plot: Class labels to draw; ``None`` draws every class.
        ax: Ignored; a new figure is always created. Kept for signature
            compatibility.
        figsize: Forwarded to ``plt.subplots``.
        cmap: Name of a matplotlib colormap.
        text_fontsize: Font size used for tick labels and legends.
        chunks: Number of class groups / subplots. Defaults to
            ``FLAGS.chunks``.
        names: Sequence indexed by ``int(label)`` to get legend text.
            Defaults to the module-level ``class_names``.

    Returns:
        The 2-D ``(ceil(chunks / 2), 2)`` array of matplotlib axes.

    Raises:
        ValueError: If ``chunks`` is less than 1.
    """
    if chunks is None:
        chunks = FLAGS.chunks
    if chunks < 1:
        raise ValueError('chunks must be >= 1, got {0}'.format(chunks))
    if names is None:
        names = class_names

    y_true = np.array(y_true)
    probas = np.array(y_probas)

    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    binarized_y_true = label_binarize(y_true, classes=classes)
    if len(classes) == 2:
        # label_binarize returns a single column for two classes; rebuild
        # one column per class so binarized_y_true[:, i] lines up with probas.
        binarized_y_true = np.hstack(
            (1 - binarized_y_true, binarized_y_true))

    # Round the row count up so an odd number of chunks still fits, and keep
    # the axes array 2-D even when there is only one row.
    n_rows = (chunks + 1) // 2
    _, axes = plt.subplots(n_rows, 2, figsize=figsize, squeeze=False)

    # Positions (columns of probas) of the classes that were asked for.
    plot_indices = np.flatnonzero(np.isin(classes, classes_to_plot))
    # Ceiling division; never 0 so the colour/chunk arithmetic below is safe.
    chunk_size = max(1, (len(plot_indices) + chunks - 1) // chunks)
    print('Chunk size', chunk_size)

    colormap = plt.get_cmap(cmap)
    for k, i in enumerate(plot_indices):
        img_class = classes[i]
        average_precision = average_precision_score(
            binarized_y_true[:, i], probas[:, i])
        precision, recall, _ = precision_recall_curve(
            y_true, probas[:, i], pos_label=img_class)
        chunk = k // chunk_size
        # Colours restart per subplot so neighbouring curves stay distinct.
        color = colormap(float(k % chunk_size) / chunk_size)
        axes[chunk // 2, chunk % 2].plot(
            recall, precision, lw=2,
            label='{0} (area = {1:0.3f})'.format(
                names[int(img_class)], average_precision),
            color=color)

    if plot_micro:
        precision, recall, _ = precision_recall_curve(
            binarized_y_true.ravel(), probas.ravel())
        average_precision = average_precision_score(
            binarized_y_true, probas, average='micro')
        axes[-1, 1].plot(recall, precision,
                         label='micro-average PR '
                               '(area = {0:0.3f})'.format(average_precision),
                         color='navy', linestyle=':', linewidth=4)

    for axis in axes.ravel():
        axis.set_xlim([0.0, 1.0])
        axis.set_ylim([0.0, 1.05])
        axis.set_xlabel('Recall')
        axis.set_ylabel('Precision')
        axis.tick_params(labelsize=text_fontsize)
        # legend() warns when a subplot has no labelled artists (possible
        # when classes_to_plot selects fewer classes than there are chunks).
        if axis.get_legend_handles_labels()[0]:
            axis.legend(loc='lower left', fontsize=text_fontsize)
    return axes


plot_precision_recall(y_true, y_proba, text_fontsize="xx-small",
                      classes_to_plot=[3, 16, 41, 70, 77, 82])
plt.show()