diff options
author | Vasil Zlatanov <v@skozl.com> | 2018-11-06 21:24:38 +0000 |
---|---|---|
committer | Vasil Zlatanov <v@skozl.com> | 2018-11-06 21:24:38 +0000 |
commit | adf17d6fe8ae651297e604644392a67dd5bc96a0 (patch) | |
tree | 792a37351fe0f42f33162d0f168c6b5a4328c554 | |
parent | e356f1344ca131e98f362a26262844691290c6a4 (diff) | |
download | vz215_np1915-adf17d6fe8ae651297e604644392a67dd5bc96a0.tar.gz vz215_np1915-adf17d6fe8ae651297e604644392a67dd5bc96a0.tar.bz2 vz215_np1915-adf17d6fe8ae651297e604644392a67dd5bc96a0.zip |
Add alternative per class PCA
-rwxr-xr-x | train.py | 93 |
1 files changed, 53 insertions, 40 deletions
@@ -59,15 +59,16 @@ def test_split(n_faces, raw_faces, split, seed): return faces_train, faces_test, target_train, target_test def draw_conf_mat(args, target_test, target_pred): - cm = confusion_matrix(target_test, target_pred) acc_sc = accuracy_score(target_test, target_pred) - print('Accuracy: ', acc_sc) - if (args.conf_mat): - plt.matshow(cm, cmap='Blues') - plt.colorbar() - plt.ylabel('Actual') - plt.xlabel('Predicted') - plt.show() + if not args.classifyalt: + cm = confusion_matrix(target_test, target_pred) + print('Accuracy: ', acc_sc) + if (args.conf_mat): + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() return acc_sc def test_model(M, faces_train, faces_test, target_train, target_test, args): @@ -87,12 +88,10 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): faces_train = normalise_faces(average_face, faces_train) faces_test = normalise_faces(average_face, faces_test) if (args.pca_r): - print('Reduced PCA') e_vals, e_vecs = LA.eigh(np.dot(faces_train, faces_train.T)) e_vecs = np.dot(faces_train.T, e_vecs) e_vecs = e_vecs/LA.norm(e_vecs, axis = 0) else: - print('Standard PCA') e_vals, e_vecs = LA.eigh(np.cov(faces_train.T)) # e_vecs = normalise_faces(np.mean(e_vecs,axis=0), e_vecs) #PLOTTING NON-ZERO EVALS @@ -104,18 +103,16 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): e_vecs = np.fliplr(e_vecs).T[:M] deviations_tr = np.flip(deviations_tr) deviations_tst = np.flip(deviations_tst) - if (args.classifyalt): - faces_train = np.mean(faces_train.reshape([n_faces,8,2576]), axis=1) - target_train = range(n_faces) - raw_faces_train = np.mean(raw_faces_train.reshape([n_faces,8,2576]), axis=1) faces_train = np.dot(faces_train, e_vecs.T) faces_test = np.dot(faces_test, e_vecs.T) - if (args.reconstruct): + distances = np.zeros(faces_test.shape[0]) + for i in range(faces_test.shape[0]): + norm = LA.norm(faces_train - np.tile(faces_test[i], (faces_train.shape[0], 1)), axis=1) + distances[i] = np.amin(norm) + + if args.reconstruct: rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr) - rec_faces_test = np.add(average_face, np.dot(faces_test, e_vecs) * deviations_tst) -#THERE MIGHT BE A RECONSTRUCTION PROBLEM DUE TO DEVIATIONS_TST - rec_error = LA.norm(np.subtract(raw_faces_train[args.reconstruct], rec_vec)) ar = plt.subplot(2, 1, 1) ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray') ar = plt.subplot(2, 1, 2) @@ -163,28 +160,23 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args): plt.show() classifier = KNeighborsClassifier(n_neighbors=args.neighbors) - if (args.reconstruct): - classifier.fit(raw_faces_train, target_train) - target_pred = classifier.predict(rec_faces_test) - #Better Passing n_neighbors = 1 - else: - classifier.fit(faces_train, target_train) - target_pred = classifier.predict(faces_test) - if args.prob: - targer_prob = classifier.predict_proba(faces_test) - targer_prob_vec = np.zeros(104) - for i in range (104): - j = int(np.floor(i/2)) - targer_prob_vec [i] = targer_prob[i][j] - avg_targer_prob = np.zeros(n_faces) - for i in range (n_faces): - avg_targer_prob[i] = (targer_prob_vec[2*i] + targer_prob_vec[2*i + 1])/2 + classifier.fit(faces_train, target_train) + target_pred = classifier.predict(faces_test) + if args.prob: + targer_prob = classifier.predict_proba(faces_test) + targer_prob_vec = np.zeros(104) + for i in range (104): + j = int(np.floor(i/2)) + targer_prob_vec [i] = targer_prob[i][j] + avg_targer_prob = np.zeros(n_faces) + for i in range (n_faces): + avg_targer_prob[i] = (targer_prob_vec[2*i] + targer_prob_vec[2*i + 1])/2 #WE CAN FIX THIS BY RESHAPING TARGER_PROB_VEC AND TAKING THE MEAN ON THE RIGHT AXIS - plt.bar(range(n_faces), avg_targer_prob) - plt.show() + plt.bar(range(n_faces), avg_targer_prob) + plt.show() #Better n_neighbors = 2 - return draw_conf_mat(args, target_test, target_pred) + return draw_conf_mat(args, target_test, target_pred), distances def main(): parser = argparse.ArgumentParser() @@ -213,12 +205,33 @@ def main(): targets = np.repeat(np.arange(n_faces),n_cases) faces_train, faces_test, target_train, target_test = test_split(n_faces, raw_faces, args.split, args.seed) - + + if args.classifyalt: + faces_train = faces_train.reshape(n_faces, 8, n_pixels) + target_train = target_train.reshape(n_faces, 8) + + accuracy = np.zeros(n_faces) + distances = np.zeros((n_faces, faces_test.shape[0])) + for i in range(n_faces): + accuracy[i], distances[i] = test_model(args.eigen, faces_train[i], faces_test, target_train[i], target_test, args) + target_pred = np.argmin(distances, axis=0) + acc_sc = accuracy_score(target_test, target_pred) + cm = confusion_matrix(target_test, target_pred) + print('Total Accuracy: ', acc_sc) + if (args.conf_mat): + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() + return + if args.reigen: - accuracy = np.zeros(args.reigen - args.eigen) + accuracy = np.zeros(n_faces) + rec_error = np.zeros(n_faces) for M in range(args.eigen, args.reigen): start = timer() - accuracy[M - args.eigen] = test_model(M, faces_train, faces_test, target_train, target_test, args) + accuracy[M - args.eigen], rec_error[M - args.eigen] = test_model(M, faces_train, faces_test, target_train, target_test, args) end = timer() print("Run with", M, "eigenvalues completed in ", end-start, "seconds") print("Memory Used:", psutil.Process(os.getpid()).memory_info().rss) |