From a1985d4f28fdf596f4c223b34598f9d82eb5a154 Mon Sep 17 00:00:00 2001 From: nunzip Date: Sun, 9 Dec 2018 16:31:22 +0000 Subject: Fix mAP --- evaluate.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/evaluate.py b/evaluate.py index ca6c8db..9218649 100755 --- a/evaluate.py +++ b/evaluate.py @@ -134,20 +134,33 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam if args.mAP: precision = np.zeros((probe_label.shape[0], args.neighbors)) + recall = np.zeros((probe_label.shape[0], args.neighbors)) mAP = np.zeros(probe_label.shape[0]) - truths = precision + max_level_precision = np.zeros((probe_label.shape[0],11)) + for i in range(probe_label.shape[0]): truth_count=0 false_count=0 for j in range(args.neighbors): if probe_label[i] == nneighbors[i][j]: - truths[i][j] = 1 truth_count+=1 precision[i][j] = truth_count/(j+1) - mAP[i] += precision[i][j] + else: + false_count+=1 + precision[i][j]= 1 - false_count/(j+1) if truth_count!=0: - mAP[i] = mAP[i]/truth_count + recall_step = 1/truth_count + for j in range(args.neighbors): + if probe_label[i] == nneighbors[i][j]: + recall[i][j:] += recall_step + else: + recall[i][:] = 1 + for i in range(probe_label.shape[0]): + for j in range(11): + max_level_precision[i][j] = np.max(precision[i][np.where(recall[i]>=(j/10))]) #print(mAP[i]) + for i in range(probe_label.shape[0]): + mAP[i] = sum(max_level_precision[i])/11 print('mAP:',np.mean(mAP)) return target_pred -- cgit v1.2.3-54-g00ecf