diff options
| -rwxr-xr-x | part2.py | 102 | 
1 files changed, 68 insertions, 34 deletions
| @@ -32,18 +32,21 @@ from rerank import re_ranking  parser = argparse.ArgumentParser()  parser.add_argument("-t", "--test", help="Use test data instead of query", action='store_true') -parser.add_argument("-cm", "--conf_mat", help="Show visual confusion matrix", action='store_true') -parser.add_argument("-km", "--kmean", help="Perform Kmeans", action='store_true', default=0) -parser.add_argument("-ma", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0) +parser.add_argument("-c", "--conf_mat", help="Show visual confusion matrix", action='store_true') +parser.add_argument("-k", "--kmean", help="Perform Kmeans", action='store_true', default=0) +parser.add_argument("-m", "--mahalanobis", help="Perform Mahalanobis Distance metric", action='store_true', default=0)  parser.add_argument("-e", "--euclidean", help="Standard euclidean", action='store_true', default=0)  parser.add_argument("-r", "--rerank", help="Use k-reciprocal rernaking", action='store_true') -parser.add_argument("-ka", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20) -parser.add_argument("-kb", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6) +parser.add_argument("-p", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20) +parser.add_argument("-q", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6) +parser.add_argument("-l", "--rerankl", help="Coefficient to combine distances", type=int, default = 0.3)  parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1)  parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')  parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0)  parser.add_argument("-2", "--graphspace", help="Graph space", action='store_true', default=0)  parser.add_argument("-1", "--norm", help="Normalized features", action='store_true', default=0) +parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes equal to M", type=int, default=1) +parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1)  args = parser.parse_args() @@ -52,7 +55,6 @@ def verbose(*text):      if args.verbose:          print(text) -#prob query, gal train  def draw_results(test_label, pred_label):      acc_sc = accuracy_score(test_label, pred_label)      cm = confusion_matrix(test_label, pred_label) @@ -63,7 +65,7 @@ def draw_results(test_label, pred_label):          plt.ylabel('Actual')          plt.xlabel('Predicted')          plt.show() -    return +    return acc_sc  def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, showfiles_train, showfiles_test, args): @@ -81,38 +83,47 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam              distances = np.zeros((probe_label.size, gallery_label.size))              for i in range(probe_label.size):                  print('Mahalanobis step ', i, '/', probe_label.size) -                distances [i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat) +                distances[i] = cdist(probe_data[i].reshape((1,2048)), gallery_data, 'mahalanobis', VI = covmat)          else:              distances = cdist(probe_data, gallery_data, 'euclidean')       ranklist = np.argsort(distances, axis=1) - -    target_pred = np.zeros(ranklist.shape[0]) -    nneighbors = np.zeros((ranklist.shape[0],args.neighbors)) -    nnshowrank = (np.zeros((ranklist.shape[0],args.neighbors))).astype(object) -    for probe_idx in range(probe_data.shape[0]): -        row = ranklist[probe_idx] -        n = 0 -        q = 0 -        while (q < args.neighbors): -            while (probe_cam[probe_idx] == gallery_cam[row[n]] and -              probe_label[probe_idx] == gallery_label[row[n]]): +    test_table = np.arange(1, args.multrank+1) +    target_pred = np.zeros((args.multrank, ranklist.shape[0])) +    nsize = args.neighbors +    if (args.multrank != 1): +        nsize = test_table[args.multrank-1] +    nneighbors = np.zeros((ranklist.shape[0],nsize)) +    nnshowrank = (np.zeros((ranklist.shape[0],nsize))).astype(object) +     +    for i in range(args.multrank): +        args.neighbors = test_table[i] +        for probe_idx in range(probe_data.shape[0]): +            row = ranklist[probe_idx] +            n = 0 +            q = 0 +            while (q < args.neighbors): +                while (probe_cam[probe_idx] == gallery_cam[row[n]] and +                  probe_label[probe_idx] == gallery_label[row[n]]): +                    n += 1 +                nneighbors[probe_idx][q] = gallery_label[row[n]] +                nnshowrank[probe_idx][q] = showfiles_train[row[n]] # +                q += 1                  n += 1 -            nneighbors[probe_idx][q] = gallery_label[row[n]] -            nnshowrank[probe_idx][q] = showfiles_train[row[n]] # -            q += 1 -            n += 1 -        if (args.neighbors) and (probe_label[probe_idx] in nneighbors[probe_idx]): -                target_pred[probe_idx] = probe_label[probe_idx] -        else: -            target_pred[probe_idx] = nneighbors[probe_idx][0] +                 +            if (args.neighbors) and (probe_label[probe_idx] in nneighbors[probe_idx]): +                target_pred[i][probe_idx] = probe_label[probe_idx] +            else: +                target_pred[i][probe_idx] = nneighbors[probe_idx][0] +                  if (args.showrank):               with open("ranklist.txt", "w") as text_file:                  text_file.write(np.array2string(nnshowrank[:args.showrank]))              with open("query.txt", "w") as text_file:                  text_file.write(np.array2string(showfiles_test[:args.showrank])) +                      if args.graphspace:              # Colors for distinct individuals              cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(1467)] @@ -150,6 +161,10 @@ def main():      test_label = labels[query_idx]      train_cam = camId[gallery_idx]      test_cam = camId[query_idx] +     +    accuracy = np.zeros((2, args.multrank)) +    test_table = np.arange(1, args.multrank+1) +      if (args.norm):          train_data = np.divide(train_data,LA.norm(train_data, axis=0))          test_data = np.divide(test_data, LA.norm(test_data, axis=0)) @@ -211,13 +226,32 @@ def main():          cluster = np.array(cl)          clusterlabel = np.array(cllab)          clustercam = np.array(clcam) - -        target_pred = test_model(cluster, test_data, clusterlabel, test_label, clustercam, test_cam, showfiles_train, showfiles_test, args) -        draw_results(test_label, target_pred) -     -    else:     -        target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) -        draw_results(test_label, target_pred) +         +        for q in range(args.comparison): +            target_pred = test_model(cluster, test_data, clusterlabel, test_label, clustercam, test_cam, showfiles_train, showfiles_test, args) +            for i in range(args.multrank): +                accuracy[q][i] = draw_results(test_label, target_pred[i])    +            args.rerank = True +            args.neighbors = 1 +         +    else: +        for q in range(args.comparison): +            target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) +            for i in range(args.multrank): +                accuracy[q][i] = draw_results(test_label, target_pred[i])    +            args.rerank = True +            args.neighbors = 1 +             +    if(args.multrank != 1): +        plt.plot(test_table[:(args.multrank)], 100*accuracy[0]) +        if(args.comparison!=1): +            plt.plot(test_table[:(args.multrank)], 100*accuracy[1]) +            plt.legend(['Baseline kNN', 'Improved metric'], loc='upper left')     +        plt.xlabel('k rank') +        plt.ylabel('Recognition Accuracy (%)') +        plt.grid(True) +        plt.show() +              if __name__ == "__main__":      main() | 
