#!/usr/bin/python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scikitplot as skplt

from sklearn.preprocessing import label_binarize
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import auc, confusion_matrix
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import classification_report

flags = tf.app.flags

flags.DEFINE_string('softmax', None, 'The softmax.npz file containing labels and predicted probabilities')
flags.DEFINE_string('dinfo', None, 'The dinfo.npz file with dataset info (class names)')
flags.DEFINE_integer('chunks', 4, 'The number of plots to produce')


FLAGS = flags.FLAGS
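
# Example invocation (paths are illustrative):
#   ./plot-report --softmax=run/softmax.npz --dinfo=run/dinfo.npz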


def plot_classification_report(report, title='Classification report', cmap='RdBu'):
    '''
    Plot a scikit-learn classification report as a heatmap.
    Extension based on https://stackoverflow.com/a/31689645/395857
    Expects the plain-text report produced by classification_report().
    '''
    lines = report.split('\n')

    plotMat = []
    support = []
    class_names = []
    # Each class row looks like '<name> <precision> <recall> <f1> <support>';
    # the slice below skips the header and the trailing average/summary rows.
    for line in lines[2 : (len(lines) - 2)]:
        t = line.strip().split()
        if len(t) < 2: continue
        class_names.append(t[0])
        v = [float(x) for x in t[1 : len(t) - 1]]
        support.append(int(t[-1]))
        print(v)
        plotMat.append(v)

    print('plotMat: {0}'.format(plotMat))
    print('support: {0}'.format(support))

    xlabel = 'Metrics'
    ylabel = 'Classes'
    xticklabels = ['Precision', 'Recall', 'F1-score']
    yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup in enumerate(support)]
    figure_width = 25
    figure_height = len(class_names) + 7
    correct_orientation = False
    heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
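
# plot_classification_report() above calls heatmap(), which is not defined
# anywhere in this file. The helper below is a minimal sketch, loosely based
# on the StackOverflow answer linked in the docstring; treat the width/height
# arguments as rough figure dimensions, not an exact reconstruction.
def heatmap(values, title, xlabel, ylabel, xticklabels, yticklabels,
            figure_width, figure_height, correct_orientation, cmap='RdBu'):
    fig, ax = plt.subplots(figsize=(figure_width / 4.0, figure_height / 4.0))
    c = ax.pcolor(values, edgecolors='k', linewidths=0.2, cmap=cmap)
    # Center the tick labels on each cell.
    ax.set_xticks(np.arange(values.shape[1]) + 0.5, minor=False)
    ax.set_yticks(np.arange(values.shape[0]) + 0.5, minor=False)
    ax.set_xticklabels(xticklabels, minor=False)
    ax.set_yticklabels(yticklabels, minor=False)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if correct_orientation:
        # Read classes top-to-bottom, like a table.
        ax.invert_yaxis()
        ax.xaxis.tick_top()
    fig.colorbar(c)
    plt.show()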

# softmax.npz is expected to hold 'labels' (true class per sample) and
# 'predictions' (per-class softmax probabilities); dinfo.npz holds 'classes'.
softmax = np.load(FLAGS.softmax)
dinfo = np.load(FLAGS.dinfo)

class_names = dinfo['classes']

y_true = softmax['labels']
y_proba = softmax['predictions']

y_true_sparse = label_binarize(y_true, classes=np.unique(y_true))  # one-hot label matrix (N, n_classes)
y_pred = np.argmax(y_proba, axis=1)  # predicted class index per sample
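
# y_true_sparse is not used further below; one illustrative use, tying in the
# precision_recall_curve / average_precision_score imports above (a sketch,
# commented out so the script's behavior is unchanged; assumes every class
# occurs in y_true):
#   for i in range(y_true_sparse.shape[1]):
#       precision, recall, _ = precision_recall_curve(y_true_sparse[:, i], y_proba[:, i])
#       ap = average_precision_score(y_true_sparse[:, i], y_proba[:, i])
#       plt.plot(recall, precision, label='{0} (AP={1:.2f})'.format(class_names[i], ap))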

cl_report = classification_report(y_true, y_pred, target_names=class_names, labels=np.arange(len(class_names)))
print(cl_report)
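
# The report text can be rendered with the heatmap helper defined above
# (commented out to keep the script's behavior unchanged):
#   plot_classification_report(cl_report)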

cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))  # rows: true, cols: predicted
print(cm)

def top_wrong(cm, N=150):
    '''Return (row, col) indices of the N largest confusion-matrix cells,
    sorted in descending order of count. (The [-N:][::-1] double slice is
    equivalent to the single slice [:-N-1:-1].)'''
    idx = np.argsort(cm.ravel())[-N:][::-1]
    return np.c_[np.unravel_index(idx, cm.shape)]

# Print the most frequent off-diagonal (true class, predicted class) confusions.
for idxs in top_wrong(cm):
    if idxs[0] != idxs[1]:
        print(class_names[idxs[0]], "\t", class_names[idxs[1]], "\t", cm[idxs[0], idxs[1]])

benign_class = int(np.where(class_names == 'benign')[0][0])  # scalar index of the 'benign' class

benign_pages = np.where(y_true.ravel() == benign_class)[0]  # sample indices whose true label is benign

cnt = 0   # benign pages that were misclassified (false positives)
cnt9 = 0  # false positives predicted with softmax confidence > 0.95
for benign_page in benign_pages:
    guess = y_pred[benign_page]
    if guess != benign_class:
        softmax_val = y_proba[benign_page][guess]
        cnt += 1
        if softmax_val > 0.95:
            print("B: " + class_names[guess] + "\t" + str(softmax_val))
            cnt9 += 1

print('We have ' + str(cnt9) + ' false positives with softmax > 0.95, out of ' + str(cnt) + ' misclassified / ' + str(benign_pages.size) + ' benign pages')
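
# Vectorized sanity check of the two counters above (a sketch using numpy
# fancy indexing; relies on benign_pages and y_pred being 1-D index arrays).
wrong = y_pred[benign_pages] != benign_class
conf = y_proba[benign_pages, y_pred[benign_pages]]
assert cnt == int(np.sum(wrong))
assert cnt9 == int(np.sum(wrong & (conf > 0.95)))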