From 39ba3506dc3a67c08c20e048237798812a7bd6ea Mon Sep 17 00:00:00 2001 From: nunzip Date: Tue, 11 Dec 2018 19:23:24 +0000 Subject: Add test for lambda --- opt.py | 59 +++++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 18 deletions(-) mode change 100644 => 100755 opt.py (limited to 'opt.py') diff --git a/opt.py b/opt.py old mode 100644 new mode 100755 index 0f94de4..3d68adc --- a/opt.py +++ b/opt.py @@ -162,7 +162,8 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam #mAP[i] = sum(max_level_precision[i])/11 mAP[i] = sum(precision[i])/args.neighbors print('mAP:',np.mean(mAP)) - + return np.mean(mAP) + return target_pred def eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args): @@ -235,6 +236,9 @@ def eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) accuracy[0] = draw_results(test_label, target_pred) else: for q in range(args.comparison): + if args.mAP: + return test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) + target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) for i in range(args.multrank): return draw_results(test_label, target_pred[i]) @@ -251,17 +255,7 @@ def eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) plt.grid(True) plt.show() -def main(): - mat = scipy.io.loadmat(os.path.join(args.data,'cuhk03_new_protocol_config_labeled.mat')) - camId = mat['camId'] - filelist = mat['filelist'] - labels = mat['labels'] - gallery_idx = mat['gallery_idx'] - 1 - query_idx = mat['query_idx'] - 1 - train_idx = mat['train_idx'] - 1 - with open(os.path.join(args.data,'feature_data.json'), 'r') as read_file: - feature_vectors = np.array(json.load(read_file)) - +def kopt(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args): axis = 0 search = 0 steps = 0 @@ -269,12 +263,16 @@ def main(): neg = False outofaxis = False start = np.array([1,1]) - args.PCA = 10 + if args.mAP: + args.neighbors = 10 + args.PCA = 50 args.train = True args.rerank = True args.reranka = 1 args.rerankb = 1 opt = np.array([1,1]) + checktab = np.zeros((100,100)) + checktab[1][1]=1 max_acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) print('origin') print('vertical') @@ -297,8 +295,10 @@ def main(): p = search*2 + start[0] args.reranka = p if not outofaxis: - print('p:',p,' q:',q) - acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) + if checktab[p][q] == 0: + checktab[p][q] = 1 + print('p:',p,' q:',q) + acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) if acc > max_acc: print('new p:',p, ' for accuracy:', acc) max_acc=acc @@ -318,8 +318,10 @@ def main(): q = search*2 + start[1] args.rerankb = q if not outofaxis: - print('p:',p,' q:',q) - acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) + if checktab[p][q] == 0: + checktab[p][q]=1 + print('p:',p,' q:',q) + acc = eval(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) if acc > max_acc: print('new q:',q, ' for accuracy:', acc) max_acc=acc @@ -351,8 +353,29 @@ def main(): opt[1] = start[1] steps=0 vertical=True - print('Maximum Accuracy:',max_acc,' found at p:',opt[0],'|q:',opt[1]) + print('Maximum Accuracy:',max_acc,' found at p:',opt[0],'|q:',opt[1]) + return max_acc, opt +def main(): + mat = scipy.io.loadmat(os.path.join(args.data,'cuhk03_new_protocol_config_labeled.mat')) + camId = mat['camId'] + filelist = mat['filelist'] + labels = mat['labels'] + gallery_idx = mat['gallery_idx'] - 1 + query_idx = mat['query_idx'] - 1 + train_idx = mat['train_idx'] - 1 + with open(os.path.join(args.data,'feature_data.json'), 'r') as read_file: + feature_vectors = np.array(json.load(read_file)) + l=0 + max_acc = np.zeros(11) + opt = np.zeros((11,2)) + while l < 11: + args.rerankl = l/10 + print('testing for lambda:',args.rerankl) + max_acc[l], opt[l] = kopt(camId, filelist, labels, gallery_idx, train_idx, feature_vectors, args) + l +=1 + print('Max accuracy:',np.max(max_acc),' at p:',opt[np.argmax(max_acc)][0], '| q:',opt[np.argmax(max_acc)][1],'| lambda:',np.argmax(max_acc)/10) + if __name__ == "__main__": main() -- cgit v1.2.3-54-g00ecf