diff options
author | nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal> | 2018-12-04 15:28:33 +0000 |
---|---|---|
committer | nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal> | 2018-12-04 15:28:33 +0000 |
commit | a8573c29ba5410b9b9e02b0c6624525447114fd6 (patch) | |
tree | 804b3db25879e599ebc26363118ca578a812540d | |
parent | a0b7adfd4a7df4079a4ae73729f9ec0feb6eb8c8 (diff) | |
download | vz215_np1915-a8573c29ba5410b9b9e02b0c6624525447114fd6.tar.gz vz215_np1915-a8573c29ba5410b9b9e02b0c6624525447114fd6.tar.bz2 vz215_np1915-a8573c29ba5410b9b9e02b0c6624525447114fd6.zip |
Rewrite Mahalanobis
-rwxr-xr-x | part2.py | 39 |
1 files changed, 27 insertions, 12 deletions
@@ -43,6 +43,9 @@ parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, d parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0) parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0) +parser.add_argument("-2", "--graphspace", help="Graph space", action='store_true', default=0) +parser.add_argument("-1", "--norm", help="Normalized features", action='store_true', default=0) + args = parser.parse_args() @@ -64,21 +67,24 @@ def draw_results(test_label, pred_label): return def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, showfiles_train, showfiles_test, args): - # metric = 'jaccard' is also valid - if args.mahalanobis: - metric = 'sqeuclidean' - else: - metric = 'euclidean' - + verbose("probe shape:", probe_data.shape) verbose("gallery shape:", gallery_data.shape) - + if args.rerank: distances = re_ranking(probe_data, gallery_data, args.reranka ,args.rerankb , 0.3, MemorySave = False, Minibatch = 2000) else: - distances = cdist(probe_data, gallery_data, metric) + if args.mahalanobis: + # metric = 'jaccard' is also valid + covmat = LA.inv(np.cov(gallery_data)) + distances = np.zeros((probe_label.size, gallery_label.size)) + for i in range(probe_label.size): + print('Mahalanobis step ', i, '/', probe_label.size) + distances [i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat) + else: + distances = cdist(probe_data, gallery_data, 'euclidean') ranklist = np.argsort(distances, axis=1) @@ -106,13 +112,21 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam else: target_pred[probe_idx] = nneighbors[probe_idx][0] - + if (args.showrank): with open("ranklist.txt", "w") as text_file: text_file.write(np.array2string(nnshowrank[:args.showrank])) with open("query.txt", "w") as text_file: text_file.write(np.array2string(showfiles_test[:args.showrank])) - + if args.graphspace: + # Colors for distinct individuals + cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(1467)] + gallery_label_tmp = np.subtract(gallery_label, 1) + pltCol = [cols[int(k)] for k in gallery_label_tmp] + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(gallery_data[:, 0], gallery_data[:, 1], gallery_data[:, 2], marker='o', color=pltCol) + plt.show() return target_pred def main(): @@ -141,7 +155,9 @@ def main(): test_label = labels[query_idx] train_cam = camId[gallery_idx] test_cam = camId[query_idx] - + if (args.norm): + train_data = np.divide(train_data,LA.norm(train_data, axis=0)) + test_data = np.divide(test_data, LA.norm(test_data, axis=0)) if(args.kmean): gallery1 = [] gallery2 = [] @@ -200,7 +216,6 @@ def main(): cluster = np.array(cl) clusterlabel = np.array(cllab) clustercam = np.array(clcam) - print(cluster.shape, clusterlabel.shape, clustercam.shape) target_pred = test_model(cluster, test_data, clusterlabel, test_label, clustercam, test_cam, showfiles_train, showfiles_test, args) draw_results(test_label, target_pred) |