diff options
| -rwxr-xr-x | part2.py | 32 | 
1 files changed, 25 insertions, 7 deletions
@@ -39,7 +39,10 @@ parser.add_argument("-e", "--euclidean", help="Standard euclidean", action='stor  parser.add_argument("-r", "--rerank", help="Use k-reciprocal rernaking", action='store_true')  parser.add_argument("-ka", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20)  parser.add_argument("-kb", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6) +parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1)  parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') +parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0) +  args = parser.parse_args()  def verbose(*text): @@ -59,7 +62,7 @@ def draw_results(test_label, pred_label):          plt.show()      return -def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam): +def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, args):      # metric = 'jaccard' is also valid      if args.mahalanobis:          metric = 'sqeuclidean' @@ -79,14 +82,29 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam      ranklist = np.argsort(distances, axis=1)      target_pred = np.zeros(ranklist.shape[0]) +    nneighbors = np.zeros((ranklist.shape[0],args.neighbors))      for probe_idx in range(probe_data.shape[0]):          row = ranklist[probe_idx] -        n = 0  -        while (probe_cam[probe_idx] == gallery_cam[row[n]] and -          probe_label[probe_idx] == gallery_label[row[n]]): +        n = 0 +        q = 0 +        while (q < args.neighbors): +            while (probe_cam[probe_idx] == gallery_cam[row[n]] and +              probe_label[probe_idx] == gallery_label[row[n]]): +                n += 1 +            nneighbors[probe_idx][q] = gallery_label[row[n]] +            #NEED TO ADD NNIDX SAME WAY TO PRINT ACTUAL RANKLIST +            q += 1              n += 1 -        target_pred[probe_idx] = gallery_label[row[n]] - +        if (args.inrank): +            if (probe_label[probe_idx] in nneighbors[probe_idx]): +                target_pred[probe_idx] = probe_label[probe_idx] +            else: +                target_pred[probe_idx] = nneighbors[probe_idx][0] +        else: +            target_pred[probe_idx] = nneighbors[probe_idx][0] +            #target_pred[probe_idx] = np.argmax(np.bincount(nneighbors[probe_idx].astype(int))) ###BIN LABELS FROM NN +         +          return target_pred  def main(): @@ -145,7 +163,7 @@ def main():          draw_results(test_label_1, target_pred)      else:     -        target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam) +        target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, args)          draw_results(test_label, target_pred)  if __name__ == "__main__":  | 
