Diffstat (limited to 'util/plot-report')
-rwxr-xr-x | util/plot-report | 102
1 file changed, 102 insertions, 0 deletions
diff --git a/util/plot-report b/util/plot-report
new file mode 100755
index 0000000..927437f
--- /dev/null
+++ b/util/plot-report
@@ -0,0 +1,102 @@
+#!/usr/bin/python
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+import scikitplot as skplt
+
+from sklearn.preprocessing import label_binarize
+from sklearn.preprocessing import LabelEncoder
+from sklearn.metrics import auc, confusion_matrix
+from sklearn.metrics import precision_recall_curve
+from sklearn.metrics import average_precision_score
+from sklearn.metrics import classification_report
+
+flags = tf.app.flags
+
+flags.DEFINE_string('softmax', None, 'The softmax.npz file contained labels and probas')
+flags.DEFINE_string('dinfo', None, 'The dinfo.npz file')
+flags.DEFINE_integer('chunks', 4, 'The number of plots to produce')
+
+
+FLAGS = flags.FLAGS
+
+
+def plot_classification_report(classification_report, title='Classification report ', cmap='RdBu'):
+    '''
+    Plot scikit-learn classification report.
+    Extension based on https://stackoverflow.com/a/31689645/395857
+    '''
+    lines = classification_report.split('\n')
+
+    classes = []
+    plotMat = []
+    support = []
+    class_names = []
+    for line in lines[2 : (len(lines) - 2)]:
+        t = line.strip().split()
+        if len(t) < 2: continue
+        classes.append(t[0])
+        v = [float(x) for x in t[1: len(t) - 1]]
+        support.append(int(t[-1]))
+        class_names.append(t[0])
+        print(v)
+        plotMat.append(v)
+
+    print('plotMat: {0}'.format(plotMat))
+    print('support: {0}'.format(support))
+
+    xlabel = 'Metrics'
+    ylabel = 'Classes'
+    xticklabels = ['Precision', 'Recall', 'F1-score']
+    yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup in enumerate(support)]
+    figure_width = 25
+    figure_height = len(class_names) + 7
+    correct_orientation = False
+    heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
+
+softmax = np.load(FLAGS.softmax)
+dinfo = np.load(FLAGS.dinfo)
+
+class_names=dinfo['classes']
+
+y_true = softmax['labels']
+y_proba = softmax['predictions']
+
+y_true_sparse = label_binarize(y_true, classes=np.unique(y_true))
+y_pred = np.argmax(y_proba, axis=1)
+
+cl_report= classification_report(y_true, y_pred, target_names=class_names, labels=np.arange(len(class_names)))
+print(cl_report)
+
+cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
+print(cm)
+
+def top_wrong(cm, N=5):
+    a=cm
+    N = 150
+    idx = np.argsort(a.ravel())[-N:][::-1] #single slicing: `[:N-2:-1]`
+    topN_val = a.ravel()[idx]
+    row_col = np.c_[np.unravel_index(idx, a.shape)]
+    return row_col
+
+#print(top_wrong(cm))
+for idxs in top_wrong(cm):
+    if idxs[0] != idxs[1]:
+        print(class_names[idxs[0]],"\t",class_names[idxs[1]],"\t",cm[idxs[0], idxs[1]])
+
+benign_class = np.where(class_names=='benign')
+
+benign_pages, _ = np.where(y_true == benign_class)
+
+cnt=0
+cnt9=0
+for benign_page in benign_pages:
+    guess = y_pred[benign_page]
+    if guess != benign_class:
+        softmax_val = y_proba[benign_page][guess]
+        cnt +=1
+        if softmax_val > 0.95:
+            print("B: " + class_names[guess] + "\t" + str(softmax_val))
+            cnt9 += 1
+
+print('We have ' + str(cnt9) + ' false-positives with softmax > 0.95 out of ' +str(cnt) + '/' + str(benign_pages.size))
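Note that plot_classification_report() calls a heatmap() helper that is neither defined nor imported anywhere in this file (the docstring points at https://stackoverflow.com/a/31689645/395857, where the original helper lives), and the function itself is never invoked at module level in this version. A minimal matplotlib stand-in that matches the positional arguments used at the call site might look like the sketch below; it is an assumption about the intended behaviour, not part of the commit.

# Hypothetical stand-in for the heatmap() helper the script calls but does not define.
import numpy as np
import matplotlib.pyplot as plt

def heatmap(values, title, xlabel, ylabel, xticklabels, yticklabels,
            figure_width=25, figure_height=20, correct_orientation=False, cmap='RdBu'):
    # Render a 2-D array of scores (rows = classes, columns = metrics) as a colored grid.
    fig, ax = plt.subplots(figsize=(figure_width / 4.0, figure_height / 4.0))
    mesh = ax.pcolor(values, edgecolors='k', linewidths=0.2, cmap=cmap)
    # Center the tick labels on each cell.
    ax.set_xticks(np.arange(values.shape[1]) + 0.5)
    ax.set_yticks(np.arange(values.shape[0]) + 0.5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if correct_orientation:
        # Put the first class at the top so rows read in report order.
        ax.invert_yaxis()
        ax.xaxis.tick_top()
    fig.colorbar(mesh, ax=ax)
    plt.show()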
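The expected inputs can be inferred from the keys the script reads: softmax.npz appears to hold 'labels' (integer ground-truth class ids) and 'predictions' (per-example softmax probabilities, shape (N, C)), while dinfo.npz holds 'classes' (the id-to-name mapping, which must include a 'benign' entry for the false-positive count at the end). The snippet below builds synthetic files of that shape purely for illustration; the class names other than 'benign' are placeholders, and the real files come from the rest of the pipeline.

# Hypothetical synthetic inputs, inferred from the keys the script loads.
import numpy as np

n_samples, n_classes = 1000, 5
np.savez('softmax.npz',
         labels=np.random.randint(0, n_classes, size=n_samples),               # true class ids
         predictions=np.random.dirichlet(np.ones(n_classes), size=n_samples))  # softmax rows, shape (N, C)
np.savez('dinfo.npz',
         classes=np.array(['benign', 'class1', 'class2', 'class3', 'class4'])) # id -> name mapping

With those files in place, the invocation would presumably be ./util/plot-report --softmax softmax.npz --dinfo dinfo.npz; the --chunks flag is defined but not referenced in this version of the script.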