From a0b7adfd4a7df4079a4ae73729f9ec0feb6eb8c8 Mon Sep 17 00:00:00 2001 From: nunzip Date: Mon, 3 Dec 2018 17:29:12 +0000 Subject: Fix kmeans --- part2.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 16 deletions(-) (limited to 'part2.py') diff --git a/part2.py b/part2.py index e9231d8..78cc29d 100755 --- a/part2.py +++ b/part2.py @@ -86,8 +86,6 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam nneighbors = np.zeros((ranklist.shape[0],args.neighbors)) nnshowrank = (np.zeros((ranklist.shape[0],args.neighbors))).astype(object) - print(showfiles_train.shape) - print(showfiles_train.dtype) for probe_idx in range(probe_data.shape[0]): row = ranklist[probe_idx] n = 0 @@ -115,7 +113,6 @@ def test_model(gallery_data, probe_data, gallery_label, probe_label, gallery_cam with open("query.txt", "w") as text_file: text_file.write(np.array2string(showfiles_test[:args.showrank])) - return target_pred def main(): @@ -146,34 +143,67 @@ def main(): test_cam = camId[query_idx] if(args.kmean): - km_labels_1 = np.arange(1,np.max(labels)+1) - km_labels_2 = np.arange(1,np.max(labels)+1) - km_train_data_1 = np.zeros(((km_labels_1.size),(feature_vectors.shape[1]))) - km_train_data_2 = np.zeros(((km_labels_2.size),(feature_vectors.shape[1]))) - km_train_data_1 = KMeans(n_clusters=int(np.max(labels)),random_state=0).fit(train_data_1) - km_train_data_2 = KMeans(n_clusters=int(np.max(labels)),random_state=0).fit(train_data_2) + gallery1 = [] + gallery2 = [] + gallery1lab = [] + gallery2lab = [] + for i in range(gallery_idx.size): + if camId[gallery_idx[i]] == 1: + gallery1.append(feature_vectors[gallery_idx[i]]) + gallery1lab.append(labels[gallery_idx[i]]) + else: + gallery2.append(feature_vectors[gallery_idx[i]]) + gallery2lab.append(labels[gallery_idx[i]]) + + train1 = np.array(gallery1) + train2 = np.array(gallery2) + tlabel1 = np.array(gallery1lab) + tlabel2 = np.array(gallery2lab) + km_train_data_1 = KMeans(n_clusters=int(np.max(labels)),random_state=0).fit(train1) + km_train_data_2 = KMeans(n_clusters=int(np.max(labels)),random_state=0).fit(train2) + + ###REMAP LABELS + km_labels_1 = np.zeros(int(np.max(labels))) # clusters size + km_labels_2 = np.zeros(int(np.max(labels))) km_idx_1 = km_train_data_1.labels_ for i in range(np.max(labels)): class_vote = np.zeros(np.max(labels)) for q in range(km_idx_1.size): if km_idx_1[q]==i: - class_vote[int(train_label_1[q])-1] = class_vote[int(train_label_1[q])-1] + 1 + class_vote[int(tlabel1[q])-1] += 1 km_labels_1[i] = np.argmax(class_vote) + 1 - target_pred = test_model(km_train_data_1.cluster_centers_, test_data_2, km_labels_1, test_label_2) - draw_results(test_label_2, target_pred) - km_idx_2 = km_train_data_2.labels_ for i in range(np.max(labels)): class_vote = np.zeros(np.max(labels)) for q in range(km_idx_2.size): if km_idx_2[q]==i: - class_vote[int(train_label_2[q])-1] = class_vote[int(train_label_2[q])-1] + 1 + class_vote[int(tlabel2[q])-1] += 1 km_labels_2[i] = np.argmax(class_vote) + 1 - target_pred = test_model(km_train_data_2.cluster_centers_, test_data_1, km_labels_2, test_label_1) - draw_results(test_label_1, target_pred) + #MERGE CLUSTERS + cl = [] + cllab = [] + clcam = [] + clustercam1 = np.ones(km_labels_1.size) + clustercam2 = np.add(np.ones(km_labels_2.size), 1) + for i in range(km_labels_1.size): + cl.append(km_train_data_1.cluster_centers_[i]) + cllab.append(km_labels_1[i]) + clcam.append(clustercam1[i]) + for i in range(km_labels_2.size): + cl.append(km_train_data_2.cluster_centers_[i]) + cllab.append(km_labels_2[i]) + clcam.append(clustercam2[i]) + + cluster = np.array(cl) + clusterlabel = np.array(cllab) + clustercam = np.array(clcam) + print(cluster.shape, clusterlabel.shape, clustercam.shape) + + target_pred = test_model(cluster, test_data, clusterlabel, test_label, clustercam, test_cam, showfiles_train, showfiles_test, args) + draw_results(test_label, target_pred) else: target_pred = test_model(train_data, test_data, train_label, test_label, train_cam, test_cam, showfiles_train, showfiles_test, args) -- cgit v1.2.3-54-g00ecf