aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2018-10-18 12:45:55 +0100
committerVasil Zlatanov <v@skozl.com>2018-10-18 12:45:55 +0100
commit1bd5993aeb5cb84657cd1acb72f9b7c6e6c3553d (patch)
treec84bf0bcc612faa95bb92c0120693edd1f72b414 /train.py
parent6adf7035aff908622ff4622de5669b3ad6461c1e (diff)
downloadvz215_np1915-1bd5993aeb5cb84657cd1acb72f9b7c6e6c3553d.tar.gz
vz215_np1915-1bd5993aeb5cb84657cd1acb72f9b7c6e6c3553d.tar.bz2
vz215_np1915-1bd5993aeb5cb84657cd1acb72f9b7c6e6c3553d.zip
Add training script with variable M
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..b7e9fb0
--- /dev/null
+++ b/train.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# Train a model from sample data
+# Author: Vasil Zlatanov, Nunzio Pucci
+# EE4 Pattern Recognition coursework
+
+import argparse
+import numpy as np
+
+from numpy import genfromtxt
+from numpy import linalg as LA
+
+# subtract the normal face from each row of the face matrix
+def normalise_faces(average_face, raw_faces):
+ return np.subtract(raw_faces, np.tile(average_face, (raw_faces.shape[1],1)).T)
+
+
+# usage: train.py [-h] -i DATA -o MODEL [-m M]
+parser = argparse.ArgumentParser()
+parser.add_argument("-i", "--data", help="Input CSV file", required=True)
+parser.add_argument("-o", "--model", help="Output model file", required=True)
+parser.add_argument("-m", "--M", help="Number of eigenvalues in model", type=int)
+args = parser.parse_args()
+
+assert args.data, "No input CSV data (-i, --input-data)"
+assert args.model, "No model specified (-o, --model)"
+
+M = args.M | -1;
+
+raw_faces = genfromtxt(args.data, delimiter=',').T
+
+average_face = np.average(raw_faces, axis=1)
+normal_faces = normalise_faces(average_face, raw_faces)
+
+e_vals, e_vecs = LA.eig(np.cov(normal_faces))
+
+np.savez(args.model, e_vals=e_vals[:M], e_vecs=e_vecs[:M])