diff options
-rwxr-xr-x | part2.py | 4 |
1 files changed, 2 insertions, 2 deletions
@@ -103,8 +103,8 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam nnshowrank[probe_idx][q] = showfiles_train[row[n]] # q += 1 n += 1 - if (args.neighbors): - target_pred[probe_idx] = probe_label[probe_idx] + if (args.neighbors) and (probe_label[probe_idx] in nneighbors[probe_idx]): + target_pred[probe_idx] = probe_label[probe_idx] else: target_pred[probe_idx] = nneighbors[probe_idx][0] |