aboutsummaryrefslogtreecommitdiff
path: root/train.py
blob: b7e9fb036bbe08b794709f13cc88973fe77dc389 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python
# Train a model from sample data
# Author: Vasil Zlatanov, Nunzio Pucci
# EE4 Pattern Recognition coursework

import argparse
import numpy as np

from numpy import genfromtxt
from numpy import linalg as LA

def normalise_faces(average_face, raw_faces):
    """Subtract the average face from every face (column) of ``raw_faces``.

    Parameters
    ----------
    average_face : array_like
        Mean face with one entry per row of ``raw_faces``; may be given as a
        flat vector of length D or as a (D, 1) column.
    raw_faces : array_like
        2-D array of shape (D, N) holding one face per column.

    Returns
    -------
    numpy.ndarray
        Array of shape (D, N) with ``average_face`` removed from each column.
        The inputs are not modified.
    """
    # Reshape the mean into a (D, 1) column so NumPy broadcasts it across all
    # N columns, instead of materialising a tiled (D, N) copy first.
    return np.subtract(raw_faces, np.reshape(average_face, (-1, 1)))


# usage: train.py [-h] -i DATA -o MODEL [-m M]
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--data", help="Input CSV file", required=True)
parser.add_argument("-o", "--model", help="Output model file", required=True)
parser.add_argument("-m", "--M", help="Number of eigenvalues in model", type=int)
args = parser.parse_args()

assert args.data, "No input CSV data (-i, --input-data)"
assert args.model, "No model specified (-o, --model)"

M = args.M | -1;

raw_faces = genfromtxt(args.data, delimiter=',').T

average_face = np.average(raw_faces, axis=1)
normal_faces = normalise_faces(average_face, raw_faces)

e_vals, e_vecs = LA.eig(np.cov(normal_faces))

np.savez(args.model, e_vals=e_vals[:M], e_vecs=e_vecs[:M])