author     nunzip <np.scarh@gmail.com>  2018-12-12 23:43:10 +0000
committer  nunzip <np.scarh@gmail.com>  2018-12-12 23:43:10 +0000
commit     9631be5a9b9e90f74b3484632ab5d9f379334a50 (patch)
tree       03fc38207617242a3fe6e4a0c5dc4520a38e3f2b /evaluate.py
parent     7167c720b525b86d689775c58b0be4ec92c8fc4d (diff)
Import test_model and draw_results from evaluate.
Rewrite mAP return in test_model
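
For context: the hunks below compute an 11-point interpolated average
precision (AP) per probe and then report the mean over all probes (mAP).
A minimal standalone sketch of that computation follows; the array names
mirror the diff, but this is an illustration, not the repository's exact
code:

import numpy as np

def eleven_point_map(precision, recall):
    # precision, recall: (n_probes, n_ranks) arrays where row i holds the
    # precision/recall values at each rank cutoff for probe i.
    n_probes = precision.shape[0]
    AP = np.zeros(n_probes)
    for i in range(n_probes):
        # Interpolated precision at recall levels 0.0, 0.1, ..., 1.0 is the
        # maximum precision achieved at any recall >= that level (this
        # assumes recall reaches each level, as the diff's np.max call
        # also implicitly does).
        levels = [np.max(precision[i][recall[i] >= (j / 10)]) for j in range(11)]
        AP[i] = sum(levels) / 11          # per-probe average precision
    return AP, np.mean(AP)                # per-probe APs and their mean (mAP)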
Diffstat (limited to 'evaluate.py')
-rwxr-xr-x  evaluate.py  21
1 file changed, 14 insertions(+), 7 deletions(-)
diff --git a/evaluate.py b/evaluate.py
index 99e5eed..b178abc 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -128,7 +128,7 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
     if args.mAP:
         precision = np.zeros((probe_label.shape[0], args.neighbors))
         recall = np.zeros((probe_label.shape[0], args.neighbors))
-        mAP = np.zeros(probe_label.shape[0])
+        AP = np.zeros(probe_label.shape[0])
         max_level_precision = np.zeros((probe_label.shape[0],11))

     for i in range(probe_label.shape[0]):
@@ -152,9 +152,10 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
             for j in range(11):
                 max_level_precision[i][j] = np.max(precision[i][np.where(recall[i]>=(j/10))])
         for i in range(probe_label.shape[0]):
-            mAP[i] = sum(max_level_precision[i])/11
-        print('mAP:',np.mean(mAP))
-
+            AP[i] = sum(max_level_precision[i])/11
+        mAP = np.mean(AP)
+        print('mAP:',mAP)
+        return target_pred, mAP
     return target_pred

 def main():
@@ -231,12 +232,18 @@ def main():
             td = test_data[i].reshape(1,test_data.shape[1])
             tc = np.array([test_cam[i]])
             tl = np.array([test_label[i]])
-            target_pred[i] = (test_model(train_data[np.where(kmeans.labels_==neighbors[i])], td, train_label[np.where(kmeans.labels_==neighbors[i])], tl, train_cam[np.where(kmeans.labels_==neighbors[i])], tc, showfiles_train[np.where(kmeans.labels_==neighbors[i])], showfiles_test[i], train_model, args))
-
+            if args.mAP:
+                target_pred[i], mAP = (test_model(train_data[np.where(kmeans.labels_==neighbors[i])], td, train_label[np.where(kmeans.labels_==neighbors[i])], tl, train_cam[np.where(kmeans.labels_==neighbors[i])], tc, showfiles_train[np.where(kmeans.labels_==neighbors[i])], showfiles_test[i], train_model, args))
+            else:
+                target_pred[i] = (test_model(train_data[np.where(kmeans.labels_==neighbors[i])], td, train_label[np.where(kmeans.labels_==neighbors[i])], tl, train_cam[np.where(kmeans.labels_==neighbors[i])], tc, showfiles_train[np.where(kmeans.labels_==neighbors[i])], showfiles_test[i], train_model, args))
+
         accuracy[0] = draw_results(test_label, target_pred)
     else:
         for q in range(args.comparison+1):
-            target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, train_model, args)
+            if args.mAP:
+                target_pred, mAP = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, train_model, args)
+            else:
+                target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, train_model, args)
             for i in range(args.multrank):
                 accuracy[q][i] = draw_results(test_label, target_pred[i])
             args.rerank = True
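
A note on the design: after this commit, test_model returns a 2-tuple only
when args.mAP is set, which forces every call site to branch, as both hunks
in main() above do. A hypothetical alternative, sketched here rather than
taken from the repository, is a fixed-arity return so callers can unpack
unconditionally:

import numpy as np

def test_model_sketch(target_pred, AP, compute_map):
    # Hypothetical helper, not the repository's code: always return a pair,
    # with mAP = None when the 11-point metric was not computed, so call
    # sites can unpack one way instead of branching on args.mAP.
    if compute_map:
        return target_pred, float(np.mean(AP))
    return target_pred, None

# Example: the same unpacking works whether or not mAP was requested.
pred, mAP = test_model_sketch(np.array([1, 2, 3]), np.array([0.5, 0.7]), True)
pred, mAP = test_model_sketch(np.array([1, 2, 3]), np.array([0.5, 0.7]), False)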