From a8573c29ba5410b9b9e02b0c6624525447114fd6 Mon Sep 17 00:00:00 2001 From: nunzip Date: Tue, 4 Dec 2018 15:28:33 +0000 Subject: Rewrite Mahalanobis --- part2.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) (limited to 'part2.py') diff --git a/part2.py b/part2.py index 78cc29d..7bb72d2 100755 --- a/part2.py +++ b/part2.py @@ -43,6 +43,9 @@ parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, d parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0) parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0) +parser.add_argument("-2", "--graphspace", help="Graph space", action='store_true', default=0) +parser.add_argument("-1", "--norm", help="Normalized features", action='store_true', default=0) + args = parser.parse_args() @@ -64,21 +67,24 @@ def draw_results(test_label, pred_label): return def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, showfiles_train, showfiles_test, args): - # metric = 'jaccard' is also valid - if args.mahalanobis: - metric = 'sqeuclidean' - else: - metric = 'euclidean' - + verbose("probe shape:", probe_data.shape) verbose("gallery shape:", gallery_data.shape) - + if args.rerank: distances = re_ranking(probe_data, gallery_data, args.reranka ,args.rerankb , 0.3, MemorySave = False, Minibatch = 2000) else: - distances = cdist(probe_data, gallery_data, metric) + if args.mahalanobis: + # metric = 'jaccard' is also valid + covmat = LA.inv(np.cov(gallery_data)) + distances = np.zeros((probe_label.size, gallery_label.size)) + for i in range(probe_label.size): + print('Mahalanobis step ', i, '/', probe_label.size) + distances [i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat) + else: + distances = cdist(probe_data, gallery_data, 'euclidean') ranklist = np.argsort(distances, axis=1) @@ -106,13 +112,21 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam else: target_pred[probe_idx] = nneighbors[probe_idx][0] - + if (args.showrank): with open("ranklist.txt", "w") as text_file: text_file.write(np.array2string(nnshowrank[:args.showrank])) with open("query.txt", "w") as text_file: text_file.write(np.array2string(showfiles_test[:args.showrank])) - + if args.graphspace: + # Colors for distinct individuals + cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(1467)] + gallery_label_tmp = np.subtract(gallery_label, 1) + pltCol = [cols[int(k)] for k in gallery_label_tmp] + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(gallery_data[:, 0], gallery_data[:, 1], gallery_data[:, 2], marker='o', color=pltCol) + plt.show() return target_pred def main(): @@ -141,7 +155,9 @@ def main(): test_label = labels[query_idx] train_cam = camId[gallery_idx] test_cam = camId[query_idx] - + if (args.norm): + train_data = np.divide(train_data,LA.norm(train_data, axis=0)) + test_data = np.divide(test_data, LA.norm(test_data, axis=0)) if(args.kmean): gallery1 = [] gallery2 = [] @@ -200,7 +216,6 @@ def main(): cluster = np.array(cl) clusterlabel = np.array(cllab) clustercam = np.array(clcam) - print(cluster.shape, clusterlabel.shape, clustercam.shape) target_pred = test_model(cluster, test_data, clusterlabel, test_label, clustercam, test_cam, showfiles_train, showfiles_test, args) draw_results(test_label, target_pred) -- cgit v1.2.3-54-g00ecf