From 5bf005f123a3ba9c90544654cea38b98603adb9a Mon Sep 17 00:00:00 2001 From: nunzip Date: Sat, 8 Dec 2018 18:38:21 +0000 Subject: Kmean fixed --- evaluate.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/evaluate.py b/evaluate.py index 6561b81..ebfc34e 100755 --- a/evaluate.py +++ b/evaluate.py @@ -14,6 +14,7 @@ import json import scipy.io from random import randint from sklearn.neighbors import KNeighborsClassifier +from sklearn.neighbors import NearestNeighbors from sklearn.neighbors import DistanceMetric from sklearn.cluster import KMeans from sklearn.decomposition import PCA @@ -41,8 +42,8 @@ parser.add_argument("-k", "--kmean", help="Perform Kmeans", action='store_true', parser.add_argument("-m", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0) parser.add_argument("-e", "--euclidean", help="Standard euclidean", action='store_true', default=0) parser.add_argument("-r", "--rerank", help="Use k-reciprocal rernaking", action='store_true') -parser.add_argument("-p", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20) -parser.add_argument("-q", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6) +parser.add_argument("-p", "--reranka", help="Parameter 1 for Rerank", type=int, default = 11) +parser.add_argument("-q", "--rerankb", help="Parameter 2 for rerank", type=int, default = 3) parser.add_argument("-l", "--rerankl", help="Coefficient to combine distances", type=int, default = 0.3) parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1) parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') @@ -53,6 +54,7 @@ parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes e parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1) parser.add_argument("--data", help="Data folder with features data", default='data') parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true') +parser.add_argument("-K", "--kmean_alt", help="Validation Mode", action='store_true') @@ -98,7 +100,6 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam nneighbors = np.zeros((ranklist.shape[0],nsize)) nnshowrank = (np.zeros((ranklist.shape[0],nsize))).astype(object) - for i in range(args.multrank): if args.multrank!= 1: args.neighbors = test_table[i] @@ -192,17 +193,34 @@ def main(): if (args.normalise): debug("Normalising data") - train_data = np.divide(train_data,LA.norm(train_data, axis=0)) - test_data = np.divide(test_data, LA.norm(test_data, axis=0)) + train_data = np.divide(train_data,LA.norm(train_data,axis=0)) + test_data = np.divide(test_data, LA.norm(test_data,axis=0)) if(args.kmean): debug("Using Kmeans") train_data, train_label, train_cam = create_kmean_clusters(feature_vectors, labels,gallery_idx,camId) - for q in range(args.comparison): - target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) - for i in range(args.multrank): - accuracy[q][i] = draw_results(test_label, target_pred[i]) - args.rerank = True - args.neighbors = 1 + + if args.kmean_alt: + kmeans = KMeans(n_clusters=10, random_state=0).fit(train_data) + neigh = NearestNeighbors(n_neighbors=1) + neigh.fit(kmeans.cluster_centers_) + neighbors = neigh.kneighbors(test_data, return_distance=False) + target_pred = np.zeros(test_data.shape[0]) + for i in range(10): + print(train_label[np.where(kmeans.labels_==i)].shape) + for i in range(test_data.shape[0]): + td = test_data[i].reshape(1,test_data.shape[1]) + tc = np.array([test_cam[i]]) + tl = np.array([test_label[i]]) + target_pred[i] = (test_model(train_data[np.where(kmeans.labels_==neighbors[i])], td, train_label[np.where(kmeans.labels_==neighbors[i])], tl, train_cam[np.where(kmeans.labels_==neighbors[i])], tc, showfiles_train[np.where(kmeans.labels_==neighbors[i])], showfiles_test[i], args)) + + accuracy[0] = draw_results(test_label, target_pred) + else: + for q in range(args.comparison): + target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) + for i in range(args.multrank): + accuracy[q][i] = draw_results(test_label, target_pred[i]) + args.rerank = True + args.neighbors = 1 if(args.multrank != 1): plt.plot(test_table[:(args.multrank)], 100*accuracy[0]) -- cgit v1.2.3-54-g00ecf