From 9de78f26ac4d575ba8f6be16a67711690d263c16 Mon Sep 17 00:00:00 2001 From: nunzip Date: Tue, 6 Nov 2018 20:32:11 +0000 Subject: Add reconstruction with different M --- train.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 69806fc..5367d8b 100755 --- a/train.py +++ b/train.py @@ -115,13 +115,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr) rec_faces_test = np.add(average_face, np.dot(faces_test, e_vecs) * deviations_tst) #THERE MIGHT BE A RECONSTRUCTION PROBLEM DUE TO DEVIATIONS_TST - rec_error = LA.norm(np.subtract(raw_faces_train[args.reconstruct], rec_vec)) - ar = plt.subplot(2, 1, 1) - ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray') - ar = plt.subplot(2, 1, 2) - ar.imshow(raw_faces_train[args.reconstruct].reshape([46,56]).T, cmap = 'gist_gray') - plt.show() - + if args.lda: if args.pca_r or (args.pca and M > n_training_faces - n_faces): lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen') @@ -184,7 +178,10 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): plt.show() #Better n_neighbors = 2 - return draw_conf_mat(args, target_test, target_pred) + if args.reconstruct: + return rec_vec + else: + return draw_conf_mat(args, target_test, target_pred) def main(): parser = argparse.ArgumentParser() @@ -228,6 +225,23 @@ def main(): plt.ylabel('Recognition Accuracy (%)') plt.grid(True) plt.show() + elif args.reconstruct: + M = args.eigen + i = 0 + q = 0 + rec_vecs = np.zeros((5, 2576)) + while M < 400: + rec_vecs[i] = test_model(M, faces_train, faces_test, target_train, target_test, args) + M = M+100 + i = i+1 + while q < i: + ax = plt.subplot(1, i+1, q+1) + ax.imshow(rec_vecs[q].reshape([46, 56]).T, cmap = 'gist_gray') + q = q+1 + ax = plt.subplot(1, i+1, i+1) + ax.imshow(faces_train[args.reconstruct].reshape([46, 56]).T, cmap = 'gist_gray') + plt.show() + else: M = args.eigen start = timer() -- cgit v1.2.3-54-g00ecf