diff options
author | nunzip <np.scarh@gmail.com> | 2019-02-06 02:31:01 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-02-06 02:31:01 +0000 |
commit | 6213982b229e42fa52a758d6d63c5b5351a39aec (patch) | |
tree | df340af5710fe6c1961f523ce7b49a3110798c9e | |
parent | 712860cd450eb34f8aafb21e8873e0740976aa82 (diff) | |
download | e4-vision-6213982b229e42fa52a758d6d63c5b5351a39aec.tar.gz e4-vision-6213982b229e42fa52a758d6d63c5b5351a39aec.tar.bz2 e4-vision-6213982b229e42fa52a758d6d63c5b5351a39aec.zip |
Add main, add more flags
-rwxr-xr-x | evaluate.py | 120 |
1 files changed, 85 insertions, 35 deletions
diff --git a/evaluate.py b/evaluate.py index a7a3ab7..b524ebd 100755 --- a/evaluate.py +++ b/evaluate.py @@ -3,16 +3,17 @@ # Vasil Zlatanov, Nunzio Pucci import numpy as np +import matplotlib import matplotlib.pyplot as plt import scikitplot as skplt import argparse -from timeit import default_timer as timer import logging from logging import debug from sklearn.cluster import KMeans from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomTreesEmbedding from sklearn.metrics import accuracy_score +import time parser = argparse.ArgumentParser() parser.add_argument("-d", "--data", help="Data path", action='store_true', default='data.npz') @@ -20,30 +21,18 @@ parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", act parser.add_argument("-k", "--kmean", help="Perform kmean clustering with --kmean cluster centers", type=int, default=0) parser.add_argument("-l", "--leaves", help="Maximum leaf nodes for RF classifier", type=int, default=256) parser.add_argument("-e", "--estimators", help="number of estimators to be used", type=int, default=100) +parser.add_argument("-D", "--treedepth", help="depth of trees", type=int, default=5) parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') +parser.add_argument("-t", "--timer", help="Display execution time", action='store_true') +parser.add_argument("-T", "--testmode", help="Testmode", action='store_true') args = parser.parse_args() if args.verbose: logging.basicConfig(level=logging.DEBUG) -data = np.load(args.data) -train = data['desc_tr'] -test = data['desc_te'] -train_part = data['desc_sel'].T -logging.debug("Verbose is on") - -if (args.kmean): - logging.debug("Computing KMeans with", train_part.shape[0], "keywords") - kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=0).fit(train_part) -else: - trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part) - - -logging.debug("Generating histograms") - -def make_histogram(data): +def make_histogram(data, model, args): if args.kmean: - hist_size = args.estimators*args.kmean + hist_size = args.kmean else: hist_size = args.estimators*args.leaves @@ -51,30 +40,91 @@ def make_histogram(data): for i in range(data.shape[0]): for j in range(data.shape[1]): if (args.kmean): - histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=args.kmean) + histogram[i][j] = np.bincount(model.predict(data[i][j].T), minlength=args.kmean) else: - leaves = trees.apply(data[i][j].T) + leaves = model.apply(data[i][j].T) leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=args.leaves) histogram[i][j] = leaves.reshape(hist_size) return histogram -hist_train = make_histogram(train) -hist_test = make_histogram(test) +def run_model (data, train, test, train_part, args): + if args.timer: + start = time.time() + + if (args.kmean): + logging.debug("Computing KMeans with", train_part.shape[0], "keywords") + kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=0).fit(train_part) + hist_train = make_histogram(train, kmeans, args) + hist_test = make_histogram(test, kmeans, args) + else: + trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, max_depth=args.treedepth, n_estimators=args.estimators, random_state=0).fit(train_part) + hist_train = make_histogram(train, trees, args) + hist_test = make_histogram(test, trees, args) + + logging.debug("Generating histograms") + + + logging.debug("Keywords shape", hist_train.shape, "\n") + logging.debug("Planting trees...") + clf = RandomForestClassifier() + clf.fit( + hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])), + np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])) -logging.debug("Keywords shape", hist_train.shape, "\n") -logging.debug("Planting trees...") -clf = RandomForestClassifier() -clf.fit( - hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])), - np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])) + logging.debug("Random forests created") -logging.debug("Random forests created") + test_pred = clf.predict(hist_test.reshape((hist_test.shape[0]*hist_test.shape[1], hist_test.shape[2]))) + test_label = np.repeat(np.arange(hist_test.shape[0]), hist_test.shape[1]) -test_pred = clf.predict(hist_test.reshape((hist_test.shape[0]*hist_test.shape[1], hist_test.shape[2]))) -test_label = np.repeat(np.arange(hist_test.shape[0]), hist_test.shape[1]) + if args.timer: + end = time.time() + print("Execution time: ",end - start) + if args.conf_mat: + skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True) + plt.show() -print(accuracy_score(test_pred, test_label)) + return accuracy_score(test_pred, test_label) + + + +def main(): + data = np.load(args.data) + train = data['desc_tr'] + test = data['desc_te'] + train_part = data['desc_sel'].T + logging.debug("Verbose is on") + + if args.testmode: + args.leaves = 10000 + cnt = 0 + acc = np.zeros((5,5)) + for i in range(5): + args.estimators = (i+1)*200 + cnt+=1 + for j in range(5): + args.treedepth = j + 1 + print("Step ", cnt) + acc[i][j] = run_model (data, train, test, train_part, args) + print("Accuracy ",acc[i][j]) + cnt+=1 + fig, ax = plt.subplots() + im = ax.imshow(acc) + ax.set_xticks(np.arange(5)) + ax.set_yticks(np.arange(5)) + ax.set_xlabel('Number of trees') + ax.set_ylabel('Tree depth') + + # Loop over data dimensions and create text annotations. + for i in range(5): + for j in range(5): + text = ax.text(j, i, acc[i, j], ha="center", va="center", color="w") + ax.set_title("Accuracy varying hyper-parameters") + fig.tight_layout() + plt.show() + else: + acc = run_model (data, train, test, train_part, args) + print(acc) + +if __name__ == "__main__": + main() -if args.conf_mat: - skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True) - plt.show() |