diff options
Diffstat (limited to 'util/plot-softmax')
-rwxr-xr-x | util/plot-softmax | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/util/plot-softmax b/util/plot-softmax new file mode 100755 index 0000000..c6c2774 --- /dev/null +++ b/util/plot-softmax @@ -0,0 +1,94 @@ +#!/usr/bin/python +import tensorflow as tf +import numpy as np +import matplotlib.pyplot as plt +import scikitplot as skplt + +from sklearn.preprocessing import label_binarize +from sklearn.preprocessing import LabelEncoder +from sklearn.metrics import auc +from sklearn.metrics import precision_recall_curve +from sklearn.metrics import average_precision_score + +flags = tf.app.flags + +flags.DEFINE_string('softmax', None, 'The softmax.npz file contained labels and probas') +flags.DEFINE_string('dinfo', None, 'The dinfo.npz file') +flags.DEFINE_integer('chunks', 4, 'The number of plots to produce') + + +FLAGS = flags.FLAGS + +softmax = np.load(FLAGS.softmax) +dinfo = np.load(FLAGS.dinfo) + +class_names=dinfo['classes'] + +y_true = softmax['labels'] +y_proba = softmax['predictions'] + + +def plot_precision_recall(y_true, y_probas, + plot_micro=True, + classes_to_plot=None, ax=None, + figsize=None, cmap='nipy_spectral', + text_fontsize="medium"): + + y_true = np.array(y_true) + y_probas = np.array(y_probas) + + classes = np.unique(y_true) + probas = y_probas + + if classes_to_plot is None: + classes_to_plot = classes + + binarized_y_true = label_binarize(y_true, classes=classes) + if len(classes) == 2: + binarized_y_true = np.hstack( + (1 - binarized_y_true, binarized_y_true)) + + fig, ax = plt.subplots(int(FLAGS.chunks/2), 2, figsize=figsize) + chunk_size = int(len(classes)/FLAGS.chunks) + int(len(classes) % FLAGS.chunks > 0) + print('Chunk size', chunk_size) + + + + indices_to_plot = np.in1d(classes, classes_to_plot) + + for i, img_class in enumerate(classes): + average_precision = average_precision_score( + binarized_y_true[:, i], + probas[:, i]) + precision, recall, _ = precision_recall_curve( + y_true, probas[:, i], pos_label=img_class) + color = plt.cm.get_cmap(cmap)(float(i%chunk_size) / chunk_size) + ax[int(i/(chunk_size*2)), int(i%(chunk_size*2) > chunk_size)].plot(recall, precision, lw=2, + label='{0} ' + '(area = {1:0.3f})'.format(class_names[int(img_class)], + average_precision), + color=color) + + if plot_micro: + precision, recall, _ = precision_recall_curve( + binarized_y_true.ravel(), probas.ravel()) + average_precision = average_precision_score(binarized_y_true, + probas, + average='micro') + ax[int(FLAGS.chunks/2)-1,1].plot(recall, precision, + label='micro-average PR ' + '(area = {0:0.3f})'.format(average_precision), + color='navy', linestyle=':', linewidth=4) + + for x in range(int(FLAGS.chunks/2)): + for y in range(2): + ax[x,y].set_xlim([0.0, 1.0]) + ax[x,y].set_ylim([0.0, 1.05]) + ax[x,y].set_xlabel('Recall') + ax[x,y].set_ylabel('Precision') + ax[x,y].tick_params(labelsize=text_fontsize) + ax[x,y].legend(loc='lower left', fontsize=text_fontsize) + return ax + +plot_precision_recall(y_true, y_proba, text_fontsize="xx-small", classes_to_plot=[3,16,41,70,77,82]) +plt.show() |