aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-x[-rw-r--r--]opt.py59
1 files changed, 41 insertions, 18 deletions
diff --git a/opt.py b/opt.py
index 0f94de4..3d68adc 100644..100755
--- a/opt.py
+++ b/opt.py
@@ -162,7 +162,8 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
#mAP[i] = sum(max_level_precision[i])/11
mAP[i] = sum(precision[i])/args.neighbors
print('mAP:',np.mean(mAP))
-
+ return np.mean(mAP)
+
return target_pred
def eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args):
@@ -235,6 +236,9 @@ def eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
accuracy[0] = draw_results(test_label, target_pred)
else:
for q in range(args.comparison):
+ if args.mAP:
+ return test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args)
+
target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args)
for i in range(args.multrank):
return draw_results(test_label, target_pred[i])
@@ -251,17 +255,7 @@ def eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
plt.grid(True)
plt.show()
-def main():
- mat = scipy.io.loadmat(os.path.join(args.data,'cuhk03_new_protocol_config_labeled.mat'))
- camId = mat['camId']
- filelist = mat['filelist']
- labels = mat['labels']
- gallery_idx = mat['gallery_idx'] - 1
- query_idx = mat['query_idx'] - 1
- train_idx = mat['train_idx'] - 1
- with open(os.path.join(args.data,'feature_data.json'), 'r') as read_file:
- feature_vectors = np.array(json.load(read_file))
-
+def kopt(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args):
axis = 0
search = 0
steps = 0
@@ -269,12 +263,16 @@ def main():
neg = False
outofaxis = False
start = np.array([1,1])
- args.PCA = 10
+ if args.mAP:
+ args.neighbors = 10
+ args.PCA = 50
args.train = True
args.rerank = True
args.reranka = 1
args.rerankb = 1
opt = np.array([1,1])
+ checktab = np.zeros((100,100))
+ checktab[1][1]=1
max_acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
print('origin')
print('vertical')
@@ -297,8 +295,10 @@ def main():
p = search*2 + start[0]
args.reranka = p
if not outofaxis:
- print('p:',p,' q:',q)
- acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
+ if checktab[p][q] == 0:
+ checktab[p][q] = 1
+ print('p:',p,' q:',q)
+ acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
if acc > max_acc:
print('new p:',p, ' for accuracy:', acc)
max_acc=acc
@@ -318,8 +318,10 @@ def main():
q = search*2 + start[1]
args.rerankb = q
if not outofaxis:
- print('p:',p,' q:',q)
- acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
+ if checktab[p][q] == 0:
+ checktab[p][q]=1
+ print('p:',p,' q:',q)
+ acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
if acc > max_acc:
print('new q:',q, ' for accuracy:', acc)
max_acc=acc
@@ -351,8 +353,29 @@ def main():
opt[1] = start[1]
steps=0
vertical=True
- print('Maximum Accuracy:',max_acc,' found at p:',opt[0],'|q:',opt[1])
+ print('Maximum Accuracy:',max_acc,' found at p:',opt[0],'|q:',opt[1])
+ return max_acc, opt
+def main():
+ mat = scipy.io.loadmat(os.path.join(args.data,'cuhk03_new_protocol_config_labeled.mat'))
+ camId = mat['camId']
+ filelist = mat['filelist']
+ labels = mat['labels']
+ gallery_idx = mat['gallery_idx'] - 1
+ query_idx = mat['query_idx'] - 1
+ train_idx = mat['train_idx'] - 1
+ with open(os.path.join(args.data,'feature_data.json'), 'r') as read_file:
+ feature_vectors = np.array(json.load(read_file))
+ l=0
+ max_acc = np.zeros(11)
+ opt = np.zeros((11,2))
+ while l < 11:
+ args.rerankl = l/10
+ print('testing for lambda:',args.rerankl)
+ max_acc[l], opt[l] = kopt(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args)
+ l +=1
+ print('Max accuracy:',np.max(max_acc),' at p:',opt[np.argmax(max_acc)][0], '| q:',opt[np.argmax(max_acc)][1],'| lambda:',np.argmax(max_acc)/10)
+
if __name__ == "__main__":
main()