aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py32
1 files changed, 15 insertions, 17 deletions
diff --git a/train.py b/train.py
index 0b0ce0e..0927943 100755
--- a/train.py
+++ b/train.py
@@ -89,32 +89,30 @@ explained_variances = ()
if args.pca or args.pca_r:
# faces_pca containcts the principial components or the M most variant eigenvectors
average_face = np.mean(faces_train, axis=0)
+ deviations = np.std(faces_train, axis=0)
faces_train = normalise_faces(average_face, faces_train)
faces_test = normalise_faces(average_face, faces_test)
if (args.pca_r):
e_vals, e_vecs = LA.eigh(np.cov(faces_train))
- e_vecs_original = e_vecs
e_vecs = np.dot(faces_train.T, e_vecs)
- # e_vecs = normalise_faces(np.mean(e_vecs,axis=0), e_vecs)
- e_vecs = sc.fit_transform(e_vecs)
- ###TODO Maybe replace with our normalising function
-
- if (args.reconstruct):
- rec_vec = np.divide(average_face, np.std(average_face)).T
- e_vecs_t = e_vecs.T
- for i in range (M):
- rec_vec = np.add(rec_vec, np.dot(e_vecs_original[i][args.reconstruct], e_vecs_t[i]))
- plt.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray')
- plt.show()
else:
e_vals, e_vecs = LA.eigh(np.cov(faces_train.T))
+ # e_vecs = normalise_faces(np.mean(e_vecs,axis=0), e_vecs)
+ # e_vecs = sc.fit_transform(e_vecs)
+
+ e_vals = np.flip(e_vals)[:M]
+ e_vecs = np.fliplr(e_vecs).T[:M]
+ deviations = np.flip(deviations)
- e_vals = np.flip(e_vals)
- e_vecs = np.fliplr(e_vecs).T
- faces_train = np.dot(faces_train, e_vecs[:M].T)
- faces_test = np.dot(faces_test, e_vecs[:M].T)
-#FOR THE ASSESSMENT PRINT EIGENVALUES AND EIGENVECTORS OF BOTH CASES AND COMPARE RESULTS WITH PHYSICAL EXPLAINATIONS
+ faces_train = np.dot(faces_train, e_vecs.T)
+ faces_test = np.dot(faces_test, e_vecs.T)
+ if (args.reconstruct):
+ for face in range(args.reconstruct):
+ rec_vec = np.add(average_face, np.dot(faces_train[face], e_vecs) * deviations)
+ ar = plt.subplot(2, args.reconstruct/2, face + 1)
+ ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray')
+ plt.show()
if args.lda or (args.pca and args.lda):
lda = LinearDiscriminantAnalysis(n_components=M)