diff options
author | nunzip <np.scarh@gmail.com> | 2018-11-04 23:44:52 +0000 |
---|---|---|
committer | Vasil Zlatanov <v@skozl.com> | 2018-11-05 15:54:01 +0000 |
commit | dc110263d727ffe4abe6772723208534560355b2 (patch) | |
tree | e8870e65af61f72528b076dda8b6025de13433f1 | |
parent | 09630bb27dec915fbe59385755d62809bcfd689e (diff) | |
download | vz215_np1915-dc110263d727ffe4abe6772723208534560355b2.tar.gz vz215_np1915-dc110263d727ffe4abe6772723208534560355b2.tar.bz2 vz215_np1915-dc110263d727ffe4abe6772723208534560355b2.zip |
Insert scatter matrices rank and confidence of guess(improve)
-rwxr-xr-x | train.py | 44 |
1 files changed, 27 insertions, 17 deletions
@@ -57,7 +57,7 @@ def test_split(n_faces, raw_faces, split, seed): faces_test = faces_test.reshape(n_faces*n_test_faces, n_pixels) return faces_train, faces_test, target_train, target_test -def draw_conf_mat(target_test, target_pred): +def draw_conf_mat(args, target_test, target_pred): cm = confusion_matrix(target_test, target_pred) print(cm) if (args.conf_mat): @@ -66,7 +66,7 @@ def draw_conf_mat(target_test, target_pred): plt.ylabel('Actual') plt.xlabel('Predicted') plt.show() - print('Accuracy %fl' % accuracy_score(target_test, target_pred)) + return accuracy_score(target_test, target_pred) def test_model(M, faces_train, faces_test, target_train, target_test, args): raw_faces_train = faces_train @@ -108,12 +108,18 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): ar.imshow(raw_faces_train[args.reconstruct].reshape([46,56]).T, cmap = 'gist_gray') plt.show() - if args.lda or (args.pca and args.lda): - lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen') + if args.lda: + if args.pca_r or (args.pca and M > n_training_faces - n_faces): + lda = LinearDiscriminantAnalysis(n_components=M, solver='eigen') + else: + lda = LinearDiscriminantAnalysis(n_components=M, store_covariance='True') + faces_train = lda.fit_transform(faces_train, target_train) faces_test = lda.transform(faces_test) class_means = lda.means_ e_vals = lda.explained_variance_ratio_ + scatter_matrix = lda.covariance_ + print(LA.matrix_rank(scatter_matrix)) if args.faces: if args.lda: @@ -133,7 +139,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): plt.show() if args.grapheigen: - graph_eigen() + graph_eigen() # Colors for distinct individuals cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(n_faces)] pltCol = [cols[int(k)] for k in target_train] @@ -144,14 +150,19 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): classifier = KNeighborsClassifier(n_neighbors=args.neighbors) if (args.reconstruct): - classifier.fit(raw_faces_train, target_train) + cgassifier.fit(raw_faces_train, target_train) target_pred = classifier.predict(rec_faces_test) #Better Passing n_neighbors = 1 else: classifier.fit(faces_train, target_train) target_pred = classifier.predict(faces_test) + targer_prob = np.max(classifier.predict_proba(faces_test), axis=1).reshape([52,2]) + targer_prob = np.mean(targer_prob, axis=1) + plt.bar(range(52), targer_prob) + plt.show() + #Better n_neighbors = 2 - draw_conf_mat(target_test, target_pred) + return draw_conf_mat(args, target_test, target_pred) def main(): parser = argparse.ArgumentParser() @@ -183,18 +194,17 @@ def main(): if args.reigen: - for M in range(args.eigen, args,reigen): - start = time() - test_model(M, faces_train, faces_test, target_train, target_test, args) - end = time() - print("Run with", M, "eigenvalues completed in %.2f" % end-start, "seconds") + for M in range(args.eigen, args,reigen): + start = timer() + accuracy[M] = test_model(M, faces_train, faces_test, target_train, target_test, args) + end = timer() + print("Run with", M, "eigenvalues completed in ", end-start, "seconds") + else: M = args.eigen - start = time() - test_model(M, faces_train, faces_test, target_train, target_test, args): - end = time() - print("Run with", M, "eigenvalues completed in %.2f" % end-start, "seconds") - + start = timer() + test_model(M, faces_train, faces_test, target_train, target_test, args) + end = timer() if __name__ == "__main__": main() |