aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xpart2.py32
1 files changed, 25 insertions, 7 deletions
diff --git a/part2.py b/part2.py
index cdbd412..ad82c37 100755
--- a/part2.py
+++ b/part2.py
@@ -39,7 +39,10 @@ parser.add_argument("-e", "--euclidean", help="Standard euclidean", action='stor
parser.add_argument("-r", "--rerank", help="Use k-reciprocal rernaking", action='store_true')
parser.add_argument("-ka", "--reranka", help="Parameter 1 for Rerank", type=int, default = 20)
parser.add_argument("-kb", "--rerankb", help="Parameter 2 for rerank", type=int, default = 6)
+parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1)
parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')
+parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0)
+
args = parser.parse_args()
def verbose(*text):
@@ -59,7 +62,7 @@ def draw_results(test_label, pred_label):
plt.show()
return
-def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam):
+def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, args):
# metric = 'jaccard' is also valid
if args.mahalanobis:
metric = 'sqeuclidean'
@@ -79,14 +82,29 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
ranklist = np.argsort(distances, axis=1)
target_pred = np.zeros(ranklist.shape[0])
+ nneighbors = np.zeros((ranklist.shape[0],args.neighbors))
for probe_idx in range(probe_data.shape[0]):
row = ranklist[probe_idx]
- n = 0
- while (probe_cam[probe_idx] == gallery_cam[row[n]] and
- probe_label[probe_idx] == gallery_label[row[n]]):
+ n = 0
+ q = 0
+ while (q < args.neighbors):
+ while (probe_cam[probe_idx] == gallery_cam[row[n]] and
+ probe_label[probe_idx] == gallery_label[row[n]]):
+ n += 1
+ nneighbors[probe_idx][q] = gallery_label[row[n]]
+ #NEED TO ADD NNIDX SAME WAY TO PRINT ACTUAL RANKLIST
+ q += 1
n += 1
- target_pred[probe_idx] = gallery_label[row[n]]
-
+ if (args.inrank):
+ if (probe_label[probe_idx] in nneighbors[probe_idx]):
+ target_pred[probe_idx] = probe_label[probe_idx]
+ else:
+ target_pred[probe_idx] = nneighbors[probe_idx][0]
+ else:
+ target_pred[probe_idx] = nneighbors[probe_idx][0]
+ #target_pred[probe_idx] = np.argmax(np.bincount(nneighbors[probe_idx].astype(int))) ###BIN LABELS FROM NN
+
+
return target_pred
def main():
@@ -145,7 +163,7 @@ def main():
draw_results(test_label_1, target_pred)
else:
- target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam)
+ target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, args)
draw_results(test_label, target_pred)
if __name__ == "__main__":