path: root/evaluate.py
author    nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>  2018-12-06 15:48:07 +0000
committer nunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>  2018-12-06 15:48:07 +0000
commit    9347c15c74635345bc4c360be90af9bf5d693b30 (patch)
tree      77a500772eca97cca35f343bdfd4b61db591cd7d /evaluate.py
parent    4b1ef6bcc45fbe40b9cd9c493e9fa5c1215142c4 (diff)
Rewrite train mode
Diffstat (limited to 'evaluate.py')
-rwxr-xr-x  evaluate.py  56
1 file changed, 38 insertions(+), 18 deletions(-)
diff --git a/evaluate.py b/evaluate.py
index c5528af..6561b81 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -52,6 +52,8 @@ parser.add_argument("-1", "--normalise", help="Normalized features", action='sto
parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes equal to M", type=int, default=1)
parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1)
parser.add_argument("--data", help="Data folder with features data", default='data')
+parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true')
+
args = parser.parse_args()
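
The hunk above adds a -V/--validation switch next to the existing options. A minimal sketch of how the flag behaves once parsed; the -T/--train option is an assumption inferred from the args.train check later in the file and does not appear in this hunk:

import argparse

# hypothetical reduced parser mirroring the options above
parser = argparse.ArgumentParser()
parser.add_argument("-T", "--train", help="Train mode", action='store_true')            # assumed existing flag
parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true')  # flag added by this commit
args = parser.parse_args(["-V"])
print(args.train, args.validation)  # False True -> validation mode reuses train_idx as query and gallery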
@@ -147,22 +149,43 @@ def main():
train_idx = mat['train_idx'] - 1
with open(os.path.join(args.data,'feature_data.json'), 'r') as read_file:
feature_vectors = np.array(json.load(read_file))
+
if args.train:
- query_idx = train_idx.reshape(train_idx.shape[0])
- gallery_idx = train_idx.reshape(train_idx.shape[0])
+ cam = camId[train_idx]
+ cam = cam.reshape((cam.shape[0],1))
+ labs = labels[train_idx].reshape((labels[train_idx].shape[0],1))
+ tt = np.hstack((train_idx, cam))
+ s_train, s_test, s_train_label, s_test_label = train_test_split(tt, labs, test_size=0.3, random_state=0)
+ train, test, train_label, test_label = train_test_split(s_train, s_train_label, test_size=0.3, random_state=0)
+ # split twice to obtain a smaller training/test subset
+ del labs
+ del cam
+ train_data = feature_vectors[train[:,0]]
+ test_data = feature_vectors[test[:,0]]
+ train_cam = train[:,1]
+ test_cam = test[:,1]
+ showfiles_train = filelist[train[:,0]]
+ showfiles_test = filelist[test[:,0]]
+ del train
+ del test
+ del tt
else:
- query_idx = query_idx.reshape(query_idx.shape[0])
- gallery_idx = gallery_idx.reshape(gallery_idx.shape[0])
- camId = camId.reshape(camId.shape[0])
-
- showfiles_train = filelist[gallery_idx]
- showfiles_test = filelist[query_idx]
- train_data = feature_vectors[gallery_idx]
- test_data = feature_vectors[query_idx]
- train_label = labels[gallery_idx]
- test_label = labels[query_idx]
- train_cam = camId[gallery_idx]
- test_cam = camId[query_idx]
+ if args.validation:
+ query_idx = train_idx.reshape(train_idx.shape[0])
+ gallery_idx = train_idx.reshape(train_idx.shape[0])
+ else:
+ query_idx = query_idx.reshape(query_idx.shape[0])
+ gallery_idx = gallery_idx.reshape(gallery_idx.shape[0])
+ camId = camId.reshape(camId.shape[0])
+
+ showfiles_train = filelist[gallery_idx]
+ showfiles_test = filelist[query_idx]
+ train_data = feature_vectors[gallery_idx]
+ test_data = feature_vectors[query_idx]
+ train_label = labels[gallery_idx]
+ test_label = labels[query_idx]
+ train_cam = camId[gallery_idx]
+ test_cam = camId[query_idx]
accuracy = np.zeros((2, args.multrank))
test_table = np.arange(1, args.multrank+1)
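
The train branch above shrinks the working set by splitting twice with train_test_split. A self-contained sketch of that double split, using synthetic stand-ins for train_idx, camId and labels (shapes are illustrative only, not the dataset's):

import numpy as np
from sklearn.model_selection import train_test_split

idx  = np.arange(100).reshape(-1, 1)         # stand-in for train_idx (column vector)
cam  = np.random.randint(1, 3, (100, 1))     # stand-in for camId[train_idx]
labs = np.random.randint(0, 10, (100, 1))    # stand-in for labels[train_idx]

tt = np.hstack((idx, cam))                   # keep index and camera id together through the split
# first split: 70% / 30%
s_train, s_test, s_train_label, s_test_label = train_test_split(tt, labs, test_size=0.3, random_state=0)
# second split of the 70% leaves roughly 49% / 21% of the original data
train, test, train_label, test_label = train_test_split(s_train, s_train_label, test_size=0.3, random_state=0)
print(train.shape, test.shape)               # (49, 2) (21, 2)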
@@ -173,10 +196,7 @@ def main():
test_data = np.divide(test_data, LA.norm(test_data, axis=0))
if(args.kmean):
debug("Using Kmeans")
- train_data, train_label, train_cam = create_kmean_clusters(feature_vectors,
- labels,
- gallery_idx,
- camId)
+ train_data, train_label, train_cam = create_kmean_clusters(feature_vectors, labels, gallery_idx, camId)
for q in range(args.comparison):
target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args)
for i in range(args.multrank):
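
For reference, the normalisation step kept as context just above the kmeans branch divides every feature column by its L2 norm before matching. A minimal standalone sketch of that operation on synthetic data:

import numpy as np
from numpy import linalg as LA

test_data = np.random.rand(5, 3)                               # 5 samples, 3 features (synthetic)
test_data = np.divide(test_data, LA.norm(test_data, axis=0))   # divide each column by its L2 norm
print(LA.norm(test_data, axis=0))                              # ~[1. 1. 1.]: every feature now has unit norm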