aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2018-11-07 14:59:31 +0000
committerVasil Zlatanov <v@skozl.com>2018-11-07 14:59:31 +0000
commit8550540ad867a98b945934069aa4ce87f1ecf767 (patch)
tree09d55f7240604ff30b0f4a2e5aaf7274c7148721
parentccd81ab20d829a60211b3922eb787bee1cf59dbe (diff)
downloadvz215_np1915-8550540ad867a98b945934069aa4ce87f1ecf767.tar.gz
vz215_np1915-8550540ad867a98b945934069aa4ce87f1ecf767.tar.bz2
vz215_np1915-8550540ad867a98b945934069aa4ce87f1ecf767.zip
Use reconstruction error in alt method
-rwxr-xr-xtrain.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/train.py b/train.py
index bcf921c..917c895 100755
--- a/train.py
+++ b/train.py
@@ -112,9 +112,10 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
faces_test = np.dot(faces_test, e_vecs.T)
distances = np.zeros(faces_test.shape[0])
- for i in range(faces_test.shape[0]):
- norm = LA.norm(faces_train - np.tile(faces_test[i], (faces_train.shape[0], 1)), axis=1)
- distances[i] = np.amin(norm)
+
+
+ rec_vecs = np.add(np.tile(average_face, (faces_test.shape[0], 1)), np.dot(faces_test, e_vecs) * deviations_tr)
+ distances = LA.norm(raw_faces_test - rec_vecs, axis=1);
if args.reconstruct:
rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr)