aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtrain.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/train.py b/train.py
index f09f025..5530664 100755
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
# EE4 Pattern Recognition coursework
import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
import sys
import random
from random import randint
@@ -64,6 +65,7 @@ parser.add_argument("-2", "--grapheigen", help="Swow 2D graph of targets versus
parser.add_argument("-p", "--pca", help="Use PCA", action='store_true')
parser.add_argument("-l", "--lda", help="Use LDA", action='store_true')
parser.add_argument("-r", "--reconstruct", help="Use PCA reconstruction, specify face NR", type=int, default=0)
+parser.add_argument("-cm", "--conf_mat", help="Show visual confusion matrix", action='store_true')
parser.add_argument("-q", "--pca_r", help="Use Reduced PCA", action='store_true')
@@ -123,7 +125,7 @@ if args.pca or args.pca_r:
plt.show()
if args.lda or (args.pca and args.lda):
- lda = LinearDiscriminantAnalysis(n_components=M)
+ lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen')
faces_train = lda.fit_transform(faces_train, target_train)
faces_test = lda.transform(faces_test)
class_means = lda.means_
@@ -150,7 +152,9 @@ if args.grapheigen:
# Colors for distinct individuals
cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(52)]
pltCol = [cols[int(k)] for k in target_train]
- plt.scatter(faces_train[:, 0], faces_train[:, 1], color=pltCol)
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection='3d')
+ ax.scatter(faces_train[:, 0], faces_train[:, 1], faces_train[:, 2], marker='o', color=pltCol)
plt.show()
classifier = KNeighborsClassifier(n_neighbors=args.neighbors)
@@ -165,4 +169,10 @@ else:
cm = confusion_matrix(target_test, target_pred)
print(cm)
+if (args.conf_mat):
+ plt.matshow(cm, cmap='Blues')
+ plt.colorbar()
+ plt.ylabel('Actual')
+ plt.xlabel('Predicted')
+ plt.show()
print('Accuracy %fl' % accuracy_score(target_test, target_pred))