From 769a50e70ac253531229e1639db6bc9e401a0c43 Mon Sep 17 00:00:00 2001 From: nunzip Date: Wed, 5 Dec 2018 15:50:53 +0000 Subject: Fix speed of multrank --- part2.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/part2.py b/part2.py index 9a30aad..82a8cdd 100755 --- a/part2.py +++ b/part2.py @@ -31,7 +31,7 @@ from scipy.spatial.distance import cdist from rerank import re_ranking parser = argparse.ArgumentParser() -parser.add_argument("-t", "--test", help="Use test data instead of query", action='store_true') +parser.add_argument("-t", "--train", help="Use test data instead of query", action='store_true') parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", action='store_true') parser.add_argument("-k", "--kmean", help="Perform Kmeans", action='store_true', default=0) parser.add_argument("-m", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0) @@ -79,11 +79,7 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam else: if args.mahalanobis: # metric = 'jaccard' is also valid - covmat = LA.inv(np.cov(gallery_data)) - distances = np.zeros((probe_label.size, gallery_label.size)) - for i in range(probe_label.size): - print('Mahalanobis step ', i, '/', probe_label.size) - distances[i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat) + distances = cdist(probe_data, gallery_data, 'sqeuclidean') else: distances = cdist(probe_data, gallery_data, 'euclidean') @@ -147,12 +143,12 @@ def main(): train_idx = mat['train_idx'] - 1 with open("data/feature_data.json", "r") as read_file: feature_vectors = np.array(json.load(read_file)) - - gallery_idx = gallery_idx.reshape(gallery_idx.shape[0]) - if args.test: + if args.train: query_idx = train_idx.reshape(train_idx.shape[0]) + gallery_idx = train_idx.reshape(train_idx.shape[0]) else: query_idx = query_idx.reshape(query_idx.shape[0]) + gallery_idx = gallery_idx.reshape(gallery_idx.shape[0]) camId = camId.reshape(camId.shape[0]) showfiles_train = filelist[gallery_idx] -- cgit v1.2.3-54-g00ecf