aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2018-10-18 18:32:39 +0100
committerVasil Zlatanov <v@skozl.com>2018-10-18 18:32:39 +0100
commitbc9544dee8837c9b2c05d799ada45f4cf0bfc025 (patch)
tree9efa7ea526ac14d646fcdc61faacc754b8eaf734 /train.py
parent9411ec33b05e6d540c58171bb6a30f172d4bef5b (diff)
downloadvz215_np1915-bc9544dee8837c9b2c05d799ada45f4cf0bfc025.tar.gz
vz215_np1915-bc9544dee8837c9b2c05d799ada45f4cf0bfc025.tar.bz2
vz215_np1915-bc9544dee8837c9b2c05d799ada45f4cf0bfc025.zip
Fix plot for large M
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train.py b/train.py
index b76f5e2..134b70a 100755
--- a/train.py
+++ b/train.py
@@ -42,7 +42,7 @@ M = args.eigen
raw_faces = genfromtxt(args.data, delimiter=',')
targets = np.repeat(np.arange(10),52)
-faces_train, faces_test, target_train, target_test = train_test_split(raw_faces, targets, test_size=0.5, random_state=0)
+faces_train, faces_test, target_train, target_test = train_test_split(raw_faces, targets, test_size=0.2, random_state=0)
# This remove the mean and scales to unit variance
@@ -65,7 +65,7 @@ else:
# Plot the variances (eigenvalues) from the pca object
if args.graph:
- plt.bar(np.arange(M), explained_variances)
+ plt.bar(np.arange(explained_variances.size), explained_variances)
plt.ylabel('Varaiance ratio');plt.xlabel('Face Number')
plt.show()