From 77bb4328c094083f9647ea75eea6f16d70443749 Mon Sep 17 00:00:00 2001 From: nunzip Date: Mon, 3 Dec 2018 15:26:10 +0000 Subject: Dispay picture name from ranklist --- part2.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) (limited to 'part2.py') diff --git a/part2.py b/part2.py index ad82c37..e9231d8 100755 --- a/part2.py +++ b/part2.py @@ -42,6 +42,7 @@ parser.add_argument("-kb", "--rerankb", help="Parameter 2 for rerank", type=int, parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1) parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0) +parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0) args = parser.parse_args() @@ -62,7 +63,7 @@ def draw_results(test_label, pred_label): plt.show() return -def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, args): +def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, showfiles_train, showfiles_test, args): # metric = 'jaccard' is also valid if args.mahalanobis: metric = 'sqeuclidean' @@ -83,6 +84,10 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam target_pred = np.zeros(ranklist.shape[0]) nneighbors = np.zeros((ranklist.shape[0],args.neighbors)) + nnshowrank = (np.zeros((ranklist.shape[0],args.neighbors))).astype(object) + + print(showfiles_train.shape) + print(showfiles_train.dtype) for probe_idx in range(probe_data.shape[0]): row = ranklist[probe_idx] n = 0 @@ -92,7 +97,7 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam probe_label[probe_idx] == gallery_label[row[n]]): n += 1 nneighbors[probe_idx][q] = gallery_label[row[n]] - #NEED TO ADD NNIDX SAME WAY TO PRINT ACTUAL RANKLIST + nnshowrank[probe_idx][q] = showfiles_train[row[n]] # q += 1 n += 1 if (args.inrank): @@ -100,9 +105,15 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam target_pred[probe_idx] = probe_label[probe_idx] else: target_pred[probe_idx] = nneighbors[probe_idx][0] + else: target_pred[probe_idx] = nneighbors[probe_idx][0] - #target_pred[probe_idx] = np.argmax(np.bincount(nneighbors[probe_idx].astype(int))) ###BIN LABELS FROM NN + + if (args.showrank): + with open("ranklist.txt", "w") as text_file: + text_file.write(np.array2string(nnshowrank[:args.showrank])) + with open("query.txt", "w") as text_file: + text_file.write(np.array2string(showfiles_test[:args.showrank])) return target_pred @@ -124,7 +135,9 @@ def main(): else: query_idx = query_idx.reshape(query_idx.shape[0]) camId = camId.reshape(camId.shape[0]) - + + showfiles_train = filelist[gallery_idx] + showfiles_test = filelist[query_idx] train_data = feature_vectors[gallery_idx] test_data = feature_vectors[query_idx] train_label = labels[gallery_idx] @@ -163,7 +176,7 @@ def main(): draw_results(test_label_1, target_pred) else: - target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, args) + target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) draw_results(test_label, target_pred) if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf