aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xpart2.py39
1 files changed, 27 insertions, 12 deletions
diff --git a/part2.py b/part2.py
index 78cc29d..7bb72d2 100755
--- a/part2.py
+++ b/part2.py
@@ -43,6 +43,9 @@ parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, d
parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')
parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0)
parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0)
+parser.add_argument("-2", "--graphspace", help="Graph space", action='store_true', default=0)
+parser.add_argument("-1", "--norm", help="Normalized features", action='store_true', default=0)
+
args = parser.parse_args()
@@ -64,21 +67,24 @@ def draw_results(test_label, pred_label):
return
def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, showfiles_train, showfiles_test, args):
- # metric = 'jaccard' is also valid
- if args.mahalanobis:
- metric = 'sqeuclidean'
- else:
- metric = 'euclidean'
-
+
verbose("probe shape:", probe_data.shape)
verbose("gallery shape:", gallery_data.shape)
-
+
if args.rerank:
distances = re_ranking(probe_data, gallery_data,
args.reranka ,args.rerankb , 0.3,
MemorySave = False, Minibatch = 2000)
else:
- distances = cdist(probe_data, gallery_data, metric)
+ if args.mahalanobis:
+ # metric = 'jaccard' is also valid
+ covmat = LA.inv(np.cov(gallery_data))
+ distances = np.zeros((probe_label.size, gallery_label.size))
+ for i in range(probe_label.size):
+ print('Mahalanobis step ', i, '/', probe_label.size)
+ distances [i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat)
+ else:
+ distances = cdist(probe_data, gallery_data, 'euclidean')
ranklist = np.argsort(distances, axis=1)
@@ -106,13 +112,21 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
else:
target_pred[probe_idx] = nneighbors[probe_idx][0]
-
+
if (args.showrank):
with open("ranklist.txt", "w") as text_file:
text_file.write(np.array2string(nnshowrank[:args.showrank]))
with open("query.txt", "w") as text_file:
text_file.write(np.array2string(showfiles_test[:args.showrank]))
-
+ if args.graphspace:
+ # Colors for distinct individuals
+ cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(1467)]
+ gallery_label_tmp = np.subtract(gallery_label, 1)
+ pltCol = [cols[int(k)] for k in gallery_label_tmp]
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection='3d')
+ ax.scatter(gallery_data[:, 0], gallery_data[:, 1], gallery_data[:, 2], marker='o', color=pltCol)
+ plt.show()
return target_pred
def main():
@@ -141,7 +155,9 @@ def main():
test_label = labels[query_idx]
train_cam = camId[gallery_idx]
test_cam = camId[query_idx]
-
+ if (args.norm):
+ train_data = np.divide(train_data,LA.norm(train_data, axis=0))
+ test_data = np.divide(test_data, LA.norm(test_data, axis=0))
if(args.kmean):
gallery1 = []
gallery2 = []
@@ -200,7 +216,6 @@ def main():
cluster = np.array(cl)
clusterlabel = np.array(cllab)
clustercam = np.array(clcam)
- print(cluster.shape, clusterlabel.shape, clustercam.shape)
target_pred = test_model(cluster, test_data, clusterlabel, test_label, clustercam, test_cam, showfiles_train, showfiles_test, args)
draw_results(test_label, target_pred)