summaryrefslogtreecommitdiff
path: root/util/plot-report
diff options
context:
space:
mode:
Diffstat (limited to 'util/plot-report')
-rwxr-xr-xutil/plot-report102
1 files changed, 102 insertions, 0 deletions
diff --git a/util/plot-report b/util/plot-report
new file mode 100755
index 0000000..927437f
--- /dev/null
+++ b/util/plot-report
@@ -0,0 +1,102 @@
+#!/usr/bin/python
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+import scikitplot as skplt
+
+from sklearn.preprocessing import label_binarize
+from sklearn.preprocessing import LabelEncoder
+from sklearn.metrics import auc, confusion_matrix
+from sklearn.metrics import precision_recall_curve
+from sklearn.metrics import average_precision_score
+from sklearn.metrics import classification_report
+
+flags = tf.app.flags
+
+flags.DEFINE_string('softmax', None, 'The softmax.npz file contained labels and probas')
+flags.DEFINE_string('dinfo', None, 'The dinfo.npz file')
+flags.DEFINE_integer('chunks', 4, 'The number of plots to produce')
+
+
+FLAGS = flags.FLAGS
+
+
+def plot_classification_report(classification_report, title='Classification report ', cmap='RdBu'):
+ '''
+ Plot scikit-learn classification report.
+ Extension based on https://stackoverflow.com/a/31689645/395857
+ '''
+ lines = classification_report.split('\n')
+
+ classes = []
+ plotMat = []
+ support = []
+ class_names = []
+ for line in lines[2 : (len(lines) - 2)]:
+ t = line.strip().split()
+ if len(t) < 2: continue
+ classes.append(t[0])
+ v = [float(x) for x in t[1: len(t) - 1]]
+ support.append(int(t[-1]))
+ class_names.append(t[0])
+ print(v)
+ plotMat.append(v)
+
+ print('plotMat: {0}'.format(plotMat))
+ print('support: {0}'.format(support))
+
+ xlabel = 'Metrics'
+ ylabel = 'Classes'
+ xticklabels = ['Precision', 'Recall', 'F1-score']
+ yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup in enumerate(support)]
+ figure_width = 25
+ figure_height = len(class_names) + 7
+ correct_orientation = False
+ heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
+
+softmax = np.load(FLAGS.softmax)
+dinfo = np.load(FLAGS.dinfo)
+
+class_names=dinfo['classes']
+
+y_true = softmax['labels']
+y_proba = softmax['predictions']
+
+y_true_sparse = label_binarize(y_true, classes=np.unique(y_true))
+y_pred = np.argmax(y_proba, axis=1)
+
+cl_report= classification_report(y_true, y_pred, target_names=class_names, labels=np.arange(len(class_names)))
+print(cl_report)
+
+cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
+print(cm)
+
+def top_wrong(cm, N=5):
+ a=cm
+ N = 150
+ idx = np.argsort(a.ravel())[-N:][::-1] #single slicing: `[:N-2:-1]`
+ topN_val = a.ravel()[idx]
+ row_col = np.c_[np.unravel_index(idx, a.shape)]
+ return row_col
+
+#print(top_wrong(cm))
+for idxs in top_wrong(cm):
+ if idxs[0] != idxs[1]:
+ print(class_names[idxs[0]],"\t",class_names[idxs[1]],"\t",cm[idxs[0], idxs[1]])
+
+benign_class = np.where(class_names=='benign')
+
+benign_pages, _ = np.where(y_true == benign_class)
+
+cnt=0
+cnt9=0
+for benign_page in benign_pages:
+ guess = y_pred[benign_page]
+ if guess != benign_class:
+ softmax_val = y_proba[benign_page][guess]
+ cnt +=1
+ if softmax_val > 0.95:
+ print("B: " + class_names[guess] + "\t" + str(softmax_val))
+ cnt9 += 1
+
+print('We have ' + str(cnt9) + ' false-positives with softmax > 0.95 out of ' +str(cnt) + '/' + str(benign_pages.size))