aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-09 16:31:22 +0000
committernunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-09 16:31:22 +0000
commita1985d4f28fdf596f4c223b34598f9d82eb5a154 (patch)
tree28999ede94f3b2063ea82bb9e46204d77f389669
parente04449ee1459b1f6d66ed16df733f54367fa43dd (diff)
downloadvz215_np1915-a1985d4f28fdf596f4c223b34598f9d82eb5a154.tar.gz
vz215_np1915-a1985d4f28fdf596f4c223b34598f9d82eb5a154.tar.bz2
vz215_np1915-a1985d4f28fdf596f4c223b34598f9d82eb5a154.zip
Fix mAP
-rwxr-xr-xevaluate.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/evaluate.py b/evaluate.py
index ca6c8db..9218649 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -134,20 +134,33 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
if args.mAP:
precision = np.zeros((probe_label.shape[0], args.neighbors))
+ recall = np.zeros((probe_label.shape[0], args.neighbors))
mAP = np.zeros(probe_label.shape[0])
- truths = precision
+ max_level_precision = np.zeros((probe_label.shape[0],11))
+
for i in range(probe_label.shape[0]):
truth_count=0
false_count=0
for j in range(args.neighbors):
if probe_label[i] == nneighbors[i][j]:
- truths[i][j] = 1
truth_count+=1
precision[i][j] = truth_count/(j+1)
- mAP[i] += precision[i][j]
+ else:
+ false_count+=1
+ precision[i][j]= 1 - false_count/(j+1)
if truth_count!=0:
- mAP[i] = mAP[i]/truth_count
+ recall_step = 1/truth_count
+ for j in range(args.neighbors):
+ if probe_label[i] == nneighbors[i][j]:
+ recall[i][j:] += recall_step
+ else:
+ recall[i][:] = 1
+ for i in range(probe_label.shape[0]):
+ for j in range(11):
+ max_level_precision[i][j] = np.max(precision[i][np.where(recall[i]>=(j/10))])
#print(mAP[i])
+ for i in range(probe_label.shape[0]):
+ mAP[i] = sum(max_level_precision[i])/11
print('mAP:',np.mean(mAP))
return target_pred