From 93f09de3f1ba422758ce679ce99ad0ac17d8d4a5 Mon Sep 17 00:00:00 2001 From: nunzip Date: Tue, 30 Oct 2018 18:41:25 +0000 Subject: Improve Scatterplots(3D) and Confusion Matrix --- train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index f09f025..5530664 100755 --- a/train.py +++ b/train.py @@ -4,6 +4,7 @@ # EE4 Pattern Recognition coursework import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D import sys import random from random import randint @@ -64,6 +65,7 @@ parser.add_argument("-2", "--grapheigen", help="Swow 2D graph of targets versus parser.add_argument("-p", "--pca", help="Use PCA", action='store_true') parser.add_argument("-l", "--lda", help="Use LDA", action='store_true') parser.add_argument("-r", "--reconstruct", help="Use PCA reconstruction, specify face NR", type=int, default=0) +parser.add_argument("-cm", "--conf_mat", help="Show visual confusion matrix", action='store_true') parser.add_argument("-q", "--pca_r", help="Use Reduced PCA", action='store_true') @@ -123,7 +125,7 @@ if args.pca or args.pca_r: plt.show() if args.lda or (args.pca and args.lda): - lda = LinearDiscriminantAnalysis(n_components=M) + lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen') faces_train = lda.fit_transform(faces_train, target_train) faces_test = lda.transform(faces_test) class_means = lda.means_ @@ -150,7 +152,9 @@ if args.grapheigen: # Colors for distinct individuals cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(52)] pltCol = [cols[int(k)] for k in target_train] - plt.scatter(faces_train[:, 0], faces_train[:, 1], color=pltCol) + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(faces_train[:, 0], faces_train[:, 1], faces_train[:, 2], marker='o', color=pltCol) plt.show() classifier = KNeighborsClassifier(n_neighbors=args.neighbors) @@ -165,4 +169,10 @@ else: cm = confusion_matrix(target_test, target_pred) print(cm) +if (args.conf_mat): + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() print('Accuracy %fl' % accuracy_score(target_test, target_pred)) -- cgit v1.2.3-54-g00ecf