aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xevaluate.py40
1 files changed, 29 insertions, 11 deletions
diff --git a/evaluate.py b/evaluate.py
index 38b20bb..7ce586b 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -14,6 +14,7 @@ import json
import scipy.io
from random import randint
from sklearn.neighbors import KNeighborsClassifier
+from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import DistanceMetric
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
@@ -41,8 +42,8 @@ parser.add_argument("-k", "--kmean", help="Perform Kmeans", action='store_true',
parser.add_argument("-m", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0)
parser.add_argument("-e", "--euclidean", help="Standard euclidean", action='store_true', default=0)
parser.add_argument("-r", "--rerank", help="Use k-reciprocal rernaking", action='store_true')
-parser.add_argument("-p", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20)
-parser.add_argument("-q", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6)
+parser.add_argument("-p", "--reranka", help="Parameter 1 for Rerank", type=int, default = 11)
+parser.add_argument("-q", "--rerankb", help="Parameter 2 for rerank", type=int, default = 3)
parser.add_argument("-l", "--rerankl", help="Coefficient to combine distances", type=int, default = 0.3)
parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1)
parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')
@@ -53,6 +54,7 @@ parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes e
parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1)
parser.add_argument("--data", help="Data folder with features data", default='data')
parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true')
+parser.add_argument("-K", "--kmean_alt", help="Validation Mode", action='store_true')
@@ -99,7 +101,6 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
nneighbors = np.zeros((ranklist.shape[0],nsize))
nnshowrank = (np.zeros((ranklist.shape[0],nsize))).astype(object)
-
for i in range(args.multrank):
if args.multrank!= 1:
args.neighbors = test_table[i]
@@ -193,17 +194,34 @@ def main():
if (args.normalise):
debug("Normalising data")
- train_data = np.divide(train_data,LA.norm(train_data, axis=0))
- test_data = np.divide(test_data, LA.norm(test_data, axis=0))
+ train_data = np.divide(train_data,LA.norm(train_data,axis=0))
+ test_data = np.divide(test_data, LA.norm(test_data,axis=0))
if(args.kmean):
debug("Using Kmeans")
train_data, train_label, train_cam = create_kmean_clusters(feature_vectors, labels,gallery_idx,camId)
- for q in range(args.comparison):
- target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args)
- for i in range(args.multrank):
- accuracy[q][i] = draw_results(test_label, target_pred[i])
- args.rerank = True
- args.neighbors = 1
+
+ if args.kmean_alt:
+ kmeans = KMeans(n_clusters=10, random_state=0).fit(train_data)
+ neigh = NearestNeighbors(n_neighbors=1)
+ neigh.fit(kmeans.cluster_centers_)
+ neighbors = neigh.kneighbors(test_data, return_distance=False)
+ target_pred = np.zeros(test_data.shape[0])
+ for i in range(10):
+ print(train_label[np.where(kmeans.labels_==i)].shape)
+ for i in range(test_data.shape[0]):
+ td = test_data[i].reshape(1,test_data.shape[1])
+ tc = np.array([test_cam[i]])
+ tl = np.array([test_label[i]])
+ target_pred[i] = (test_model(train_data[np.where(kmeans.labels_==neighbors[i])], td, train_label[np.where(kmeans.labels_==neighbors[i])], tl, train_cam[np.where(kmeans.labels_==neighbors[i])], tc, showfiles_train[np.where(kmeans.labels_==neighbors[i])], showfiles_test[i], args))
+
+ accuracy[0] = draw_results(test_label, target_pred)
+ else:
+ for q in range(args.comparison):
+ target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args)
+ for i in range(args.multrank):
+ accuracy[q][i] = draw_results(test_label, target_pred[i])
+ args.rerank = True
+ args.neighbors = 1
if(args.multrank != 1):
plt.plot(test_table[:(args.multrank)], 100*accuracy[0])