aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py30
1 files changed, 22 insertions, 8 deletions
diff --git a/train.py b/train.py
index 69806fc..5367d8b 100755
--- a/train.py
+++ b/train.py
@@ -115,13 +115,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr)
rec_faces_test = np.add(average_face, np.dot(faces_test, e_vecs) * deviations_tst)
#THERE MIGHT BE A RECONSTRUCTION PROBLEM DUE TO DEVIATIONS_TST
- rec_error = LA.norm(np.subtract(raw_faces_train[args.reconstruct], rec_vec))
- ar = plt.subplot(2, 1, 1)
- ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray')
- ar = plt.subplot(2, 1, 2)
- ar.imshow(raw_faces_train[args.reconstruct].reshape([46,56]).T, cmap = 'gist_gray')
- plt.show()
-
+
if args.lda:
if args.pca_r or (args.pca and M > n_training_faces - n_faces):
lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen')
@@ -184,7 +178,10 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
plt.show()
#Better n_neighbors = 2
- return draw_conf_mat(args, target_test, target_pred)
+ if args.reconstruct:
+ return rec_vec
+ else:
+ return draw_conf_mat(args, target_test, target_pred)
def main():
parser = argparse.ArgumentParser()
@@ -228,6 +225,23 @@ def main():
plt.ylabel('Recognition Accuracy (%)')
plt.grid(True)
plt.show()
+ elif args.reconstruct:
+ M = args.eigen
+ i = 0
+ q = 0
+ rec_vecs = np.zeros((5, 2576))
+ while M < 400:
+ rec_vecs[i] = test_model(M, faces_train, faces_test, target_train, target_test, args)
+ M = M+100
+ i = i+1
+ while q < i:
+ ax = plt.subplot(1, i+1, q+1)
+ ax.imshow(rec_vecs[q].reshape([46, 56]).T, cmap = 'gist_gray')
+ q = q+1
+ ax = plt.subplot(1, i+1, i+1)
+ ax.imshow(faces_train[args.reconstruct].reshape([46, 56]).T, cmap = 'gist_gray')
+ plt.show()
+
else:
M = args.eigen
start = timer()