aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtrain.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/train.py b/train.py
index a16cb04..86e7337 100755
--- a/train.py
+++ b/train.py
@@ -101,6 +101,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
if (args.reconstruct):
rec_vec = np.add(average_face, np.dot(faces_train[args.reconstruct], e_vecs) * deviations_tr)
rec_faces_test = np.add(average_face, np.dot(faces_test, e_vecs) * deviations_tst)
+#THERE MIGHT BE A RECONSTRUCTION PROBLEM DUE TO DEVIATIONS_TST
rec_error = LA.norm(np.subtract(raw_faces_train[args.reconstruct], rec_vec))
ar = plt.subplot(2, 1, 1)
ar.imshow(rec_vec.reshape([46,56]).T, cmap = 'gist_gray')
@@ -139,7 +140,7 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
plt.show()
if args.grapheigen:
- graph_eigen()
+ #graph_eigen()
# Colors for distinct individuals
cols = ['#{:06x}'.format(randint(0, 0xffffff)) for i in range(n_faces)]
pltCol = [cols[int(k)] for k in target_train]
@@ -150,16 +151,24 @@ def test_model(M, faces_train, faces_test, target_train, target_test, args):
classifier = KNeighborsClassifier(n_neighbors=args.neighbors)
if (args.reconstruct):
- cgassifier.fit(raw_faces_train, target_train)
+ classifier.fit(raw_faces_train, target_train)
target_pred = classifier.predict(rec_faces_test)
#Better Passing n_neighbors = 1
else:
classifier.fit(faces_train, target_train)
target_pred = classifier.predict(faces_test)
- targer_prob = np.max(classifier.predict_proba(faces_test), axis=1).reshape([52,2])
- targer_prob = np.mean(targer_prob, axis=1)
- plt.bar(range(52), targer_prob)
- plt.show()
+ if args.prob:
+ targer_prob = classifier.predict_proba(faces_test)
+ targer_prob_vec = np.zeros(104)
+ for i in range (104):
+ j = int(np.floor(i/2))
+ targer_prob_vec [i] = targer_prob[i][j]
+ avg_targer_prob = np.zeros(52)
+ for i in range (52):
+ avg_targer_prob[i] = (targer_prob_vec[2*i] + targer_prob_vec[2*i + 1])/2
+ #WE CAN FIX THIS BY RESHAPING TARGER_PROB_VEC AND TAKING THE MEAN ON THE RIGHT AXIS
+ plt.bar(range(52), avg_targer_prob)
+ plt.show()
#Better n_neighbors = 2
return draw_conf_mat(args, target_test, target_pred)
@@ -169,7 +178,8 @@ def main():
parser.add_argument("-i", "--data", help="Input CSV file", required=True)
parser.add_argument("-m", "--eigen", help="Number of eigenvalues in model", type=int, default = 10 )
parser.add_argument("-M", "--reigen", help="Number of eigenvalues in model", type=int)
- parser.add_argument("-n", "--neighbors", help="How many neighbors to use", type=int, default = 3)
+ parser.add_argument("-n", "--neighbors", help="How many neighbors to use", type=int, default = 1)
+##USING STANDARD 1 FOR NN ACCURACY
parser.add_argument("-f", "--faces", help="Show faces", type=int, default = 0)
parser.add_argument("-c", "--principal", help="Show principal components", action='store_true')
parser.add_argument("-s", "--seed", help="Seed to use", type=int, default=0)
@@ -183,6 +193,7 @@ def main():
parser.add_argument("-cm", "--conf_mat", help="Show visual confusion matrix", action='store_true')
parser.add_argument("-q", "--pca_r", help="Use Reduced PCA", action='store_true')
+ parser.add_argument("-pr", "--prob", help="Certainty on each guess", action='store_true')
args = parser.parse_args()