diff options
-rwxr-xr-x | evaluate.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/evaluate.py b/evaluate.py index 64f617c..127b3de 100755 --- a/evaluate.py +++ b/evaluate.py @@ -54,7 +54,7 @@ parser.add_argument("-M", "--multrank", help="Run for different ranklist sizes e parser.add_argument("-C", "--comparison", help="Set to 2 to obtain a comparison of baseline and Improved metric", type=int, default=1) parser.add_argument("--data", help="Data folder with features data", default='data') parser.add_argument("-V", "--validation", help="Validation Mode", action='store_true') -parser.add_argument("-K", "--kmean_alt", help="Validation Mode", action='store_true') +parser.add_argument("-K", "--kmean_alt", help="Validation Mode", type=int, default=0) @@ -204,13 +204,12 @@ def main(): train_data, train_label, train_cam = create_kmean_clusters(feature_vectors, labels,gallery_idx,camId) if args.kmean_alt: - kmeans = KMeans(n_clusters=10, random_state=0).fit(train_data) + kmeans = KMeans(n_clusters=args.kmean_alt, random_state=0).fit(train_data) neigh = NearestNeighbors(n_neighbors=1) neigh.fit(kmeans.cluster_centers_) neighbors = neigh.kneighbors(test_data, return_distance=False) target_pred = np.zeros(test_data.shape[0]) - for i in range(10): - print(train_label[np.where(kmeans.labels_==i)].shape) + for i in range(test_data.shape[0]): td = test_data[i].reshape(1,test_data.shape[1]) tc = np.array([test_cam[i]]) |