summaryrefslogtreecommitdiff
path: root/util/plot-softmax
diff options
context:
space:
mode:
Diffstat (limited to 'util/plot-softmax')
-rwxr-xr-xutil/plot-softmax94
1 files changed, 94 insertions, 0 deletions
diff --git a/util/plot-softmax b/util/plot-softmax
new file mode 100755
index 0000000..c6c2774
--- /dev/null
+++ b/util/plot-softmax
@@ -0,0 +1,94 @@
+#!/usr/bin/python
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+import scikitplot as skplt
+
+from sklearn.preprocessing import label_binarize
+from sklearn.preprocessing import LabelEncoder
+from sklearn.metrics import auc
+from sklearn.metrics import precision_recall_curve
+from sklearn.metrics import average_precision_score
+
+flags = tf.app.flags
+
+flags.DEFINE_string('softmax', None, 'The softmax.npz file contained labels and probas')
+flags.DEFINE_string('dinfo', None, 'The dinfo.npz file')
+flags.DEFINE_integer('chunks', 4, 'The number of plots to produce')
+
+
+FLAGS = flags.FLAGS
+
+softmax = np.load(FLAGS.softmax)
+dinfo = np.load(FLAGS.dinfo)
+
+class_names=dinfo['classes']
+
+y_true = softmax['labels']
+y_proba = softmax['predictions']
+
+
+def plot_precision_recall(y_true, y_probas,
+ plot_micro=True,
+ classes_to_plot=None, ax=None,
+ figsize=None, cmap='nipy_spectral',
+ text_fontsize="medium"):
+
+ y_true = np.array(y_true)
+ y_probas = np.array(y_probas)
+
+ classes = np.unique(y_true)
+ probas = y_probas
+
+ if classes_to_plot is None:
+ classes_to_plot = classes
+
+ binarized_y_true = label_binarize(y_true, classes=classes)
+ if len(classes) == 2:
+ binarized_y_true = np.hstack(
+ (1 - binarized_y_true, binarized_y_true))
+
+ fig, ax = plt.subplots(int(FLAGS.chunks/2), 2, figsize=figsize)
+ chunk_size = int(len(classes)/FLAGS.chunks) + int(len(classes) % FLAGS.chunks > 0)
+ print('Chunk size', chunk_size)
+
+
+
+ indices_to_plot = np.in1d(classes, classes_to_plot)
+
+ for i, img_class in enumerate(classes):
+ average_precision = average_precision_score(
+ binarized_y_true[:, i],
+ probas[:, i])
+ precision, recall, _ = precision_recall_curve(
+ y_true, probas[:, i], pos_label=img_class)
+ color = plt.cm.get_cmap(cmap)(float(i%chunk_size) / chunk_size)
+ ax[int(i/(chunk_size*2)), int(i%(chunk_size*2) > chunk_size)].plot(recall, precision, lw=2,
+ label='{0} '
+ '(area = {1:0.3f})'.format(class_names[int(img_class)],
+ average_precision),
+ color=color)
+
+ if plot_micro:
+ precision, recall, _ = precision_recall_curve(
+ binarized_y_true.ravel(), probas.ravel())
+ average_precision = average_precision_score(binarized_y_true,
+ probas,
+ average='micro')
+ ax[int(FLAGS.chunks/2)-1,1].plot(recall, precision,
+ label='micro-average PR '
+ '(area = {0:0.3f})'.format(average_precision),
+ color='navy', linestyle=':', linewidth=4)
+
+ for x in range(int(FLAGS.chunks/2)):
+ for y in range(2):
+ ax[x,y].set_xlim([0.0, 1.0])
+ ax[x,y].set_ylim([0.0, 1.05])
+ ax[x,y].set_xlabel('Recall')
+ ax[x,y].set_ylabel('Precision')
+ ax[x,y].tick_params(labelsize=text_fontsize)
+ ax[x,y].legend(loc='lower left', fontsize=text_fontsize)
+ return ax
+
+plot_precision_recall(y_true, y_proba, text_fontsize="xx-small", classes_to_plot=[3,16,41,70,77,82])
+plt.show()