aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-08 19:11:08 +0000
committernunzip <np_scarh@e4-pattern-vm.europe-west4-a.c.electric-orbit-223819.internal>2018-12-08 19:11:08 +0000
commitdca776c1046d5fb493c89c5f9ed6ccbb4c147a5a (patch)
treef9c51a0f90b8d456116d1091660289fdd080c8b8
parent2dd2d66bbc5a7e4a989c05e8a958dd878705ee20 (diff)
downloadvz215_np1915-dca776c1046d5fb493c89c5f9ed6ccbb4c147a5a.tar.gz
vz215_np1915-dca776c1046d5fb493c89c5f9ed6ccbb4c147a5a.tar.bz2
vz215_np1915-dca776c1046d5fb493c89c5f9ed6ccbb4c147a5a.zip
Add -K size for clusters number
-rwxr-xr-xevaluate.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/evaluate.py b/evaluate.py
index 64f617c..127b3de 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -54,7 +54,7 @@ parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes e
parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1)
parser.add_argument("--data", help="Data folder with features data", default='data')
parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true')
-parser.add_argument("-K", "--kmean_alt", help="Validation Mode", action='store_true')
+parser.add_argument("-K", "--kmean_alt", help="Validation Mode", type=int, default=0)
@@ -204,13 +204,12 @@ def main():
train_data, train_label, train_cam = create_kmean_clusters(feature_vectors, labels,gallery_idx,camId)
if args.kmean_alt:
- kmeans = KMeans(n_clusters=10, random_state=0).fit(train_data)
+ kmeans = KMeans(n_clusters=args.kmean_alt, random_state=0).fit(train_data)
neigh = NearestNeighbors(n_neighbors=1)
neigh.fit(kmeans.cluster_centers_)
neighbors = neigh.kneighbors(test_data, return_distance=False)
target_pred = np.zeros(test_data.shape[0])
- for i in range(10):
- print(train_label[np.where(kmeans.labels_==i)].shape)
+
for i in range(test_data.shape[0]):
td = test_data[i].reshape(1,test_data.shape[1])
tc = np.array([test_cam[i]])