diff options
-rwxr-xr-x | evaluate.py | 33 |
1 files changed, 21 insertions, 12 deletions
diff --git a/evaluate.py b/evaluate.py index caed403..ca6c8db 100755 --- a/evaluate.py +++ b/evaluate.py @@ -48,13 +48,13 @@ parser.add_argument("-l", "--rerankl", help="Coefficient to combine distances", parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1) parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0) -parser.add_argument("-2", "--graphspace", help="Graph space", action='store_true', default=0) parser.add_argument("-1", "--normalise", help="Normalized features", action='store_true', default=0) parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes equal to M", type=int, default=1) parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1) parser.add_argument("--data", help="Data folder with features data", default='data') parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true') -parser.add_argument("-K", "--kmean_alt", help="Validation Mode", type=int, default=0) +parser.add_argument("-K", "--kmean_alt", help="Kmean", type=int, default=0) +parser.add_argument("-P", "--mAP", help="Mean Average Precision", action='store_true') @@ -131,16 +131,25 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam text_file.write(np.array2string(nnshowrank[:args.showrank])) with open("query.txt", "w") as text_file: text_file.write(np.array2string(showfiles_test[:args.showrank])) - - if args.graphspace: - # Colors for distinct individuals - cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(1467)] - gallery_label_tmp = np.subtract(gallery_label, 1) - pltCol = [cols[int(k)] for k in gallery_label_tmp] - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - ax.scatter(gallery_data[:, 0], gallery_data[:, 1], gallery_data[:, 2], marker='o', color=pltCol) - plt.show() + + if args.mAP: + precision = np.zeros((probe_label.shape[0], args.neighbors)) + mAP = np.zeros(probe_label.shape[0]) + truths = precision + for i in range(probe_label.shape[0]): + truth_count=0 + false_count=0 + for j in range(args.neighbors): + if probe_label[i] == nneighbors[i][j]: + truths[i][j] = 1 + truth_count+=1 + precision[i][j] = truth_count/(j+1) + mAP[i] += precision[i][j] + if truth_count!=0: + mAP[i] = mAP[i]/truth_count + #print(mAP[i]) + print('mAP:',np.mean(mAP)) + return target_pred def main(): |