aboutsummaryrefslogtreecommitdiff
path: root/part2.py
diff options
context:
space:
mode:
authornunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-03 15:26:10 +0000
committernunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-03 15:26:10 +0000
commit77bb4328c094083f9647ea75eea6f16d70443749 (patch)
tree7c6f1d5a90f4ef1e397cf9cd07b53ce462137adc /part2.py
parent1def9d7690e29fa635f8b0022dfb41557fa1219d (diff)
downloadvz215_np1915-77bb4328c094083f9647ea75eea6f16d70443749.tar.gz
vz215_np1915-77bb4328c094083f9647ea75eea6f16d70443749.tar.bz2
vz215_np1915-77bb4328c094083f9647ea75eea6f16d70443749.zip
Dispay picture name from ranklist
Diffstat (limited to 'part2.py')
-rwxr-xr-xpart2.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/part2.py b/part2.py
index ad82c37..e9231d8 100755
--- a/part2.py
+++ b/part2.py
@@ -42,6 +42,7 @@ parser.add_argument("-kb", "--rerankb", help="Parameter 2 for rerank", type=int,
parser.add_argument("-n", "--neighbors", help="Number of neighbors", type=int, default = 1)
parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')
parser.add_argument("-i", "--inrank", help="Checks Accuracy based on presence of label in ranklist", action='store_true', default=0)
+parser.add_argument("-s", "--showrank", help="Save ranklist pic id in a txt file", type=int, default = 0)
args = parser.parse_args()
@@ -62,7 +63,7 @@ def draw_results(test_label, pred_label):
plt.show()
return
-def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, args):
+def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam, probe_cam, showfiles_train, showfiles_test, args):
# metric = 'jaccard' is also valid
if args.mahalanobis:
metric = 'sqeuclidean'
@@ -83,6 +84,10 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
target_pred = np.zeros(ranklist.shape[0])
nneighbors = np.zeros((ranklist.shape[0],args.neighbors))
+ nnshowrank = (np.zeros((ranklist.shape[0],args.neighbors))).astype(object)
+
+ print(showfiles_train.shape)
+ print(showfiles_train.dtype)
for probe_idx in range(probe_data.shape[0]):
row = ranklist[probe_idx]
n = 0
@@ -92,7 +97,7 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
probe_label[probe_idx] == gallery_label[row[n]]):
n += 1
nneighbors[probe_idx][q] = gallery_label[row[n]]
- #NEED TO ADD NNIDX SAME WAY TO PRINT ACTUAL RANKLIST
+ nnshowrank[probe_idx][q] = showfiles_train[row[n]] #
q += 1
n += 1
if (args.inrank):
@@ -100,9 +105,15 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam
target_pred[probe_idx] = probe_label[probe_idx]
else:
target_pred[probe_idx] = nneighbors[probe_idx][0]
+
else:
target_pred[probe_idx] = nneighbors[probe_idx][0]
- #target_pred[probe_idx] = np.argmax(np.bincount(nneighbors[probe_idx].astype(int))) ###BIN LABELS FROM NN
+
+ if (args.showrank):
+ with open("ranklist.txt", "w") as text_file:
+ text_file.write(np.array2string(nnshowrank[:args.showrank]))
+ with open("query.txt", "w") as text_file:
+ text_file.write(np.array2string(showfiles_test[:args.showrank]))
return target_pred
@@ -124,7 +135,9 @@ def main():
else:
query_idx = query_idx.reshape(query_idx.shape[0])
camId = camId.reshape(camId.shape[0])
-
+
+ showfiles_train = filelist[gallery_idx]
+ showfiles_test = filelist[query_idx]
train_data = feature_vectors[gallery_idx]
test_data = feature_vectors[query_idx]
train_label = labels[gallery_idx]
@@ -163,7 +176,7 @@ def main():
draw_results(test_label_1, target_pred)
else:
- target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, args)
+ target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args)
draw_results(test_label, target_pred)
if __name__ == "__main__":