aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-x evaluate.py 67
1 files changed, 41 insertions, 26 deletions
diff --git a/evaluate.py b/evaluate.py
index 2de8f8c..a7a3ab7 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -2,64 +2,79 @@
# EE4 Selected Topics From Computer Vision Coursework
# Vasil Zlatanov, Nunzio Pucci
-DATA_FILE = 'data.npz'
-CLUSTER_CNT = 256
-KMEANS = False
-
-if KMEANS:
- N_ESTIMATORS = 1
-else:
- N_ESTIMATORS = 100
-
import numpy as np
import matplotlib.pyplot as plt
import scikitplot as skplt
-
+import argparse
+from timeit import default_timer as timer
+import logging
+from logging import debug
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomTreesEmbedding
+from sklearn.metrics import accuracy_score
-data = np.load(DATA_FILE)
+parser = argparse.ArgumentParser()
+parser.add_argument("-d", "--data", help="Data path", default='data.npz')  # takes a path value; store_true would turn -d into a boolean flag
+parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", action='store_true')
+parser.add_argument("-k", "--kmean", help="Perform kmean clustering with --kmean cluster centers", type=int, default=0)
+parser.add_argument("-l", "--leaves", help="Maximum leaf nodes for RF classifier", type=int, default=256)
+parser.add_argument("-e", "--estimators", help="number of estimators to be used", type=int, default=100)
+parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')
+args = parser.parse_args()
+if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+
+data = np.load(args.data)
train = data['desc_tr']
test = data['desc_te']
train_part = data['desc_sel'].T
+logging.debug("Verbose is on")
-if (KMEANS):
- print("Computing KMeans with", train_part.shape[0], "keywords")
- kmeans = KMeans(n_clusters=CLUSTER_CNT, n_init=N_ESTIMATORS, random_state=0).fit(train_part)
+if (args.kmean):
+ logging.debug("Computing KMeans with %d keywords", train_part.shape[0])  # logging takes %-style args, not print-style
+ kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=0).fit(train_part)
else:
- trees = RandomTreesEmbedding(max_leaf_nodes=CLUSTER_CNT, n_estimators=N_ESTIMATORS, random_state=0).fit(train_part)
+ trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part)
-print("Generating histograms")
+logging.debug("Generating histograms")
def make_histogram(data):
- histogram = np.zeros((data.shape[0], data.shape[1],CLUSTER_CNT*N_ESTIMATORS))
+ if args.kmean:
+ hist_size = args.kmean  # bincount below yields exactly args.kmean bins; estimators only sets n_init
+ else:
+ hist_size = args.estimators*args.leaves  # args.leaves bins for each of the args.estimators trees
+
+ histogram = np.zeros((data.shape[0], data.shape[1],hist_size))
for i in range(data.shape[0]):
for j in range(data.shape[1]):
- if (KMEANS):
- histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=CLUSTER_CNT)
+ if (args.kmean):
+ histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=args.kmean)
else:
leaves = trees.apply(data[i][j].T)
- leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=CLUSTER_CNT)
- histogram[i][j] = leaves.reshape(CLUSTER_CNT*N_ESTIMATORS)
+ leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=args.leaves)
+ histogram[i][j] = leaves.reshape(hist_size)
return histogram
hist_train = make_histogram(train)
hist_test = make_histogram(test)
-print("Keywords shape", hist_train.shape, "\n")
-print("Planting trees...")
+logging.debug("Keywords shape %s\n", hist_train.shape)  # logging takes %-style args, not print-style
+logging.debug("Planting trees...")
clf = RandomForestClassifier()
clf.fit(
hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])),
np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1]))
-print("Random forests created")
+logging.debug("Random forests created")
test_pred = clf.predict(hist_test.reshape((hist_test.shape[0]*hist_test.shape[1], hist_test.shape[2])))
test_label = np.repeat(np.arange(hist_test.shape[0]), hist_test.shape[1])
-skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True)
-plt.show()
+print(accuracy_score(test_label, test_pred))  # (y_true, y_pred) order
+
+if args.conf_mat:
+ skplt.metrics.plot_confusion_matrix(test_label, test_pred, normalize=True)  # true labels first so rows are true classes
+ plt.show()