aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtrain.py30
1 files changed, 8 insertions, 22 deletions
diff --git a/train.py b/train.py
index 5367d8b..69806fc 100755
--- a/train.py
+++ b/train.py
@@ -115,7 +115,13 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr)
rec_faces_test = np.add(average_face, np.dot(faces_test, e_vecs) * deviations_tst)
#THERE MIGHT BE A RECONSTRUCTION PROBLEM DUE TO DEVIATIONS_TST
-
+ rec_error = LA.norm(np.subtract(raw_faces_train[args.reconstruct], rec_vec))
+ ar = plt.subplot(2, 1, 1)
+ ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray')
+ ar = plt.subplot(2, 1, 2)
+ ar.imshow(raw_faces_train[args.reconstruct].reshape([46,56]).T, cmap = 'gist_gray')
+ plt.show()
+
if args.lda:
if args.pca_r or (args.pca and M > n_training_faces - n_faces):
lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen')
@@ -178,10 +184,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
plt.show()
#Better n_neighbors = 2
- if args.reconstruct:
- return rec_vec
- else:
- return draw_conf_mat(args, target_test, target_pred)
+ return draw_conf_mat(args, target_test, target_pred)
def main():
parser = argparse.ArgumentParser()
@@ -225,23 +228,6 @@ def main():
plt.ylabel('Recognition Accuracy (%)')
plt.grid(True)
plt.show()
- elif args.reconstruct:
- M = args.eigen
- i = 0
- q = 0
- rec_vecs = np.zeros((5, 2576))
- while M < 400:
- rec_vecs[i] = test_model(M, faces_train, faces_test, target_train, target_test, args)
- M = M+100
- i = i+1
- while q < i:
- ax = plt.subplot(1, i+1, q+1)
- ax.imshow(rec_vecs[q].reshape([46, 56]).T, cmap = 'gist_gray')
- q = q+1
- ax = plt.subplot(1, i+1, i+1)
- ax.imshow(faces_train[args.reconstruct].reshape([46, 56]).T, cmap = 'gist_gray')
- plt.show()
-
else:
M = args.eigen
start = timer()