diff options
author | nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal> | 2018-12-03 14:43:32 +0000 |
---|---|---|
committer | nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal> | 2018-12-03 14:43:32 +0000 |
commit | 1def9d7690e29fa635f8b0022dfb41557fa1219d (patch) | |
tree | 4fe5eea55d2241da01e2f64b5debbf1a2a6ed913 | |
parent | 3e6e29e3ca3ea6744c9fb6909f481cb262d512e4 (diff) | |
download | vz215_np1915-1def9d7690e29fa635f8b0022dfb41557fa1219d.tar.gz vz215_np1915-1def9d7690e29fa635f8b0022dfb41557fa1219d.tar.bz2 vz215_np1915-1def9d7690e29fa635f8b0022dfb41557fa1219d.zip |
Add ranking list classification
-rwxr-xr-x | part2.py | 32 |
1 files changed, 25 insertions, 7 deletions
@@ -39,7 +39,10 @@ parser.add_argument("-e", "--euclidean", help="Standard euclidean", action='stor parser.add_argument("-r", "--rerank", help="Use k-reciprocal rernaking", action='store_true') parser.add_argument("-ka", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20) parser.add_argument("-kb", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6) +parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1) parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') +parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0) + args = parser.parse_args() def verbose(*text): @@ -59,7 +62,7 @@ def draw_results(test_label, pred_label): plt.show() return -def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam): +def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, args): # metric = 'jaccard' is also valid if args.mahalanobis: metric = 'sqeuclidean' @@ -79,14 +82,29 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam ranklist = np.argsort(distances, axis=1) target_pred = np.zeros(ranklist.shape[0]) + nneighbors = np.zeros((ranklist.shape[0],args.neighbors)) for probe_idx in range(probe_data.shape[0]): row = ranklist[probe_idx] - n = 0 - while (probe_cam[probe_idx] == gallery_cam[row[n]] and - probe_label[probe_idx] == gallery_label[row[n]]): + n = 0 + q = 0 + while (q < args.neighbors): + while (probe_cam[probe_idx] == gallery_cam[row[n]] and + probe_label[probe_idx] == gallery_label[row[n]]): + n += 1 + nneighbors[probe_idx][q] = gallery_label[row[n]] + #NEED TO ADD NNIDX SAME WAY TO PRINT ACTUAL RANKLIST + q += 1 n += 1 - target_pred[probe_idx] = gallery_label[row[n]] - + if (args.inrank): + if (probe_label[probe_idx] in nneighbors[probe_idx]): + target_pred[probe_idx] = probe_label[probe_idx] + else: + target_pred[probe_idx] = nneighbors[probe_idx][0] + else: + target_pred[probe_idx] = nneighbors[probe_idx][0] + #target_pred[probe_idx] = np.argmax(np.bincount(nneighbors[probe_idx].astype(int))) ###BIN LABELS FROM NN + + return target_pred def main(): @@ -145,7 +163,7 @@ def main(): draw_results(test_label_1, target_pred) else: - target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam) + target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, args) draw_results(test_label, target_pred) if __name__ == "__main__": |