diff options
author | nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal> | 2018-12-05 15:50:53 +0000 |
---|---|---|
committer | nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal> | 2018-12-05 15:50:53 +0000 |
commit | 769a50e70ac253531229e1639db6bc9e401a0c43 (patch) | |
tree | a07c93988257e6340f551901aba5331108cde888 | |
parent | 390bc568d1b453da960569b26b361c338cf22e2c (diff) | |
download | vz215_np1915-769a50e70ac253531229e1639db6bc9e401a0c43.tar.gz vz215_np1915-769a50e70ac253531229e1639db6bc9e401a0c43.tar.bz2 vz215_np1915-769a50e70ac253531229e1639db6bc9e401a0c43.zip |
Fix speed of multrank
-rwxr-xr-x | part2.py | 14 |
1 files changed, 5 insertions, 9 deletions
@@ -31,7 +31,7 @@ from scipy.spatial.distance import cdist from rerank import re_ranking parser = argparse.ArgumentParser() -parser.add_argument("-t", "--test", help="Use test data instead of query", action='store_true') +parser.add_argument("-t", "--train", help="Use test data instead of query", action='store_true') parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", action='store_true') parser.add_argument("-k", "--kmean", help="Perform Kmeans", action='store_true', default=0) parser.add_argument("-m", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0) @@ -79,11 +79,7 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam else: if args.mahalanobis: # metric = 'jaccard' is also valid - covmat = LA.inv(np.cov(gallery_data)) - distances = np.zeros((probe_label.size, gallery_label.size)) - for i in range(probe_label.size): - print('Mahalanobis step ', i, '/', probe_label.size) - distances[i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat) + distances = cdist(probe_data, gallery_data, 'sqeuclidean') else: distances = cdist(probe_data, gallery_data, 'euclidean') @@ -147,12 +143,12 @@ def main(): train_idx = mat['train_idx'] - 1 with open("data/feature_data.json", "r") as read_file: feature_vectors = np.array(json.load(read_file)) - - gallery_idx = gallery_idx.reshape(gallery_idx.shape[0]) - if args.test: + if args.train: query_idx = train_idx.reshape(train_idx.shape[0]) + gallery_idx = train_idx.reshape(train_idx.shape[0]) else: query_idx = query_idx.reshape(query_idx.shape[0]) + gallery_idx = gallery_idx.reshape(gallery_idx.shape[0]) camId = camId.reshape(camId.shape[0]) showfiles_train = filelist[gallery_idx] |