diff options
author | nunzip <np.scarh@gmail.com> | 2018-11-06 20:32:11 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2018-11-06 20:32:11 +0000 |
commit | 9de78f26ac4d575ba8f6be16a67711690d263c16 (patch) | |
tree | 8bcd120b3ebb98dfdf4de268609dbba42de612fe | |
parent | e356f1344ca131e98f362a26262844691290c6a4 (diff) | |
download | vz215_np1915-9de78f26ac4d575ba8f6be16a67711690d263c16.tar.gz vz215_np1915-9de78f26ac4d575ba8f6be16a67711690d263c16.tar.bz2 vz215_np1915-9de78f26ac4d575ba8f6be16a67711690d263c16.zip |
Add reconstruction with different M
-rwxr-xr-x | train.py | 30 |
1 files changed, 22 insertions, 8 deletions
@@ -115,13 +115,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr) rec_faces_test = np.add(average_face, np.dot(faces_test, e_vecs) * deviations_tst) #THERE MIGHT BE A RECONSTRUCTION PROBLEM DUE TO DEVIATIONS_TST - rec_error = LA.norm(np.subtract(raw_faces_train[args.reconstruct], rec_vec)) - ar = plt.subplot(2, 1, 1) - ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray') - ar = plt.subplot(2, 1, 2) - ar.imshow(raw_faces_train[args.reconstruct].reshape([46,56]).T, cmap = 'gist_gray') - plt.show() - + if args.lda: if args.pca_r or (args.pca and M > n_training_faces - n_faces): lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen') @@ -184,7 +178,10 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): plt.show() #Better n_neighbors = 2 - return draw_conf_mat(args, target_test, target_pred) + if args.reconstruct: + return rec_vec + else: + return draw_conf_mat(args, target_test, target_pred) def main(): parser = argparse.ArgumentParser() @@ -228,6 +225,23 @@ def main(): plt.ylabel('Recognition Accuracy (%)') plt.grid(True) plt.show() + elif args.reconstruct: + M = args.eigen + i = 0 + q = 0 + rec_vecs = np.zeros((5, 2576)) + while M < 400: + rec_vecs[i] = test_model(M, faces_train, faces_test, target_train, target_test, args) + M = M+100 + i = i+1 + while q < i: + ax = plt.subplot(1, i+1, q+1) + ax.imshow(rec_vecs[q].reshape([46, 56]).T, cmap = 'gist_gray') + q = q+1 + ax = plt.subplot(1, i+1, i+1) + ax.imshow(faces_train[args.reconstruct].reshape([46, 56]).T, cmap = 'gist_gray') + plt.show() + else: M = args.eigen start = timer() |