aboutsummaryrefslogtreecommitdiff
path: root/part2.py
diff options
context:
space:
mode:
authornunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-05 15:50:53 +0000
committernunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-05 15:50:53 +0000
commit769a50e70ac253531229e1639db6bc9e401a0c43 (patch)
treea07c93988257e6340f551901aba5331108cde888 /part2.py
parent390bc568d1b453da960569b26b361c338cf22e2c (diff)
downloadvz215_np1915-769a50e70ac253531229e1639db6bc9e401a0c43.tar.gz
vz215_np1915-769a50e70ac253531229e1639db6bc9e401a0c43.tar.bz2
vz215_np1915-769a50e70ac253531229e1639db6bc9e401a0c43.zip
Fix speed of multrank
Diffstat (limited to 'part2.py')
-rwxr-xr-xpart2.py14
1 files changed, 5 insertions, 9 deletions
diff --git a/part2.py b/part2.py
index 9a30aad..82a8cdd 100755
--- a/part2.py
+++ b/part2.py
@@ -31,7 +31,7 @@ from scipy.spatial.distance import cdist
from rerank import re_ranking
parser = argparse.ArgumentParser()
-parser.add_argument("-t", "--test", help="Use test data instead of query", action='store_true')
+parser.add_argument("-t", "--train", help="Use test data instead of query", action='store_true')
parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", action='store_true')
parser.add_argument("-k", "--kmean", help="Perform Kmeans", action='store_true', default=0)
parser.add_argument("-m", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0)
@@ -79,11 +79,7 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
else:
if args.mahalanobis:
# metric = 'jaccard' is also valid
- covmat = LA.inv(np.cov(gallery_data))
- distances = np.zeros((probe_label.size, gallery_label.size))
- for i in range(probe_label.size):
- print('Mahalanobis step ', i, '/', probe_label.size)
- distances[i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat)
+ distances = cdist(probe_data, gallery_data, 'sqeuclidean')
else:
distances = cdist(probe_data, gallery_data, 'euclidean')
@@ -147,12 +143,12 @@ def main():
train_idx = mat['train_idx'] - 1
with open("data/feature_data.json", "r") as read_file:
feature_vectors = np.array(json.load(read_file))
-
- gallery_idx = gallery_idx.reshape(gallery_idx.shape[0])
- if args.test:
+ if args.train:
query_idx = train_idx.reshape(train_idx.shape[0])
+ gallery_idx = train_idx.reshape(train_idx.shape[0])
else:
query_idx = query_idx.reshape(query_idx.shape[0])
+ gallery_idx = gallery_idx.reshape(gallery_idx.shape[0])
camId = camId.reshape(camId.shape[0])
showfiles_train = filelist[gallery_idx]