aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xevaluate.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/evaluate.py b/evaluate.py
index ca6c8db..9218649 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -134,20 +134,33 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
if args.mAP:
precision = np.zeros((probe_label.shape[0], args.neighbors))
+ recall = np.zeros((probe_label.shape[0], args.neighbors))
mAP = np.zeros(probe_label.shape[0])
- truths = precision
+ max_level_precision = np.zeros((probe_label.shape[0],11))
+
for i in range(probe_label.shape[0]):
truth_count=0
false_count=0
for j in range(args.neighbors):
if probe_label[i] == nneighbors[i][j]:
- truths[i][j] = 1
truth_count+=1
precision[i][j] = truth_count/(j+1)
- mAP[i] += precision[i][j]
+ else:
+ false_count+=1
+ precision[i][j]= 1 - false_count/(j+1)
if truth_count!=0:
- mAP[i] = mAP[i]/truth_count
+ recall_step = 1/truth_count
+ for j in range(args.neighbors):
+ if probe_label[i] == nneighbors[i][j]:
+ recall[i][j:] += recall_step
+ else:
+ recall[i][:] = 1
+ for i in range(probe_label.shape[0]):
+ for j in range(11):
+ max_level_precision[i][j] = np.max(precision[i][np.where(recall[i]>=(j/10))])
#print(mAP[i])
+ for i in range(probe_label.shape[0]):
+ mAP[i] = sum(max_level_precision[i])/11
print('mAP:',np.mean(mAP))
return target_pred