From 712860cd450eb34f8aafb21e8873e0740976aa82 Mon Sep 17 00:00:00 2001 From: nunzip Date: Mon, 4 Feb 2019 22:43:54 +0000 Subject: Add flags --- evaluate.py | 67 +++++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/evaluate.py b/evaluate.py index 2de8f8c..a7a3ab7 100755 --- a/evaluate.py +++ b/evaluate.py @@ -2,64 +2,79 @@ # EE4 Selected Topics From Computer Vision Coursework # Vasil Zlatanov, Nunzio Pucci -DATA_FILE = 'data.npz' -CLUSTER_CNT = 256 -KMEANS = False - -if KMEANS: - N_ESTIMATORS = 1 -else: - N_ESTIMATORS = 100 - import numpy as np import matplotlib.pyplot as plt import scikitplot as skplt - +import argparse +from timeit import default_timer as timer +import logging +from logging import debug from sklearn.cluster import KMeans from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomTreesEmbedding +from sklearn.metrics import accuracy_score -data = np.load(DATA_FILE) +parser = argparse.ArgumentParser() +parser.add_argument("-d", "--data", help="Data path", action='store_true', default='data.npz') +parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", action='store_true') +parser.add_argument("-k", "--kmean", help="Perform kmean clustering with --kmean cluster centers", type=int, default=0) +parser.add_argument("-l", "--leaves", help="Maximum leaf nodes for RF classifier", type=int, default=256) +parser.add_argument("-e", "--estimators", help="number of estimators to be used", type=int, default=100) +parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') +args = parser.parse_args() +if args.verbose: + logging.basicConfig(level=logging.DEBUG) + +data = np.load(args.data) train = data['desc_tr'] test = data['desc_te'] train_part = data['desc_sel'].T +logging.debug("Verbose is on") -if (KMEANS): - print("Computing KMeans with", train_part.shape[0], "keywords") - kmeans = KMeans(n_clusters=CLUSTER_CNT, n_init=N_ESTIMATORS, random_state=0).fit(train_part) +if (args.kmean): + logging.debug("Computing KMeans with", train_part.shape[0], "keywords") + kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=0).fit(train_part) else: - trees = RandomTreesEmbedding(max_leaf_nodes=CLUSTER_CNT, n_estimators=N_ESTIMATORS, random_state=0).fit(train_part) + trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part) -print("Generating histograms") +logging.debug("Generating histograms") def make_histogram(data): - histogram = np.zeros((data.shape[0], data.shape[1],CLUSTER_CNT*N_ESTIMATORS)) + if args.kmean: + hist_size = args.estimators*args.kmean + else: + hist_size = args.estimators*args.leaves + + histogram = np.zeros((data.shape[0], data.shape[1],hist_size)) for i in range(data.shape[0]): for j in range(data.shape[1]): - if (KMEANS): - histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=CLUSTER_CNT) + if (args.kmean): + histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=args.kmean) else: leaves = trees.apply(data[i][j].T) - leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=CLUSTER_CNT) - histogram[i][j] = leaves.reshape(CLUSTER_CNT*N_ESTIMATORS) + leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=args.leaves) + histogram[i][j] = leaves.reshape(hist_size) return histogram hist_train = make_histogram(train) hist_test = make_histogram(test) -print("Keywords shape", hist_train.shape, "\n") -print("Planting trees...") +logging.debug("Keywords shape", hist_train.shape, "\n") +logging.debug("Planting trees...") clf = RandomForestClassifier() clf.fit( hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])), np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])) -print("Random forests created") +logging.debug("Random forests created") test_pred = clf.predict(hist_test.reshape((hist_test.shape[0]*hist_test.shape[1], hist_test.shape[2]))) test_label = np.repeat(np.arange(hist_test.shape[0]), hist_test.shape[1]) -skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True) -plt.show() +print(accuracy_score(test_pred, test_label)) + +if args.conf_mat: + skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True) + plt.show() -- cgit v1.2.3-54-g00ecf