aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2018-11-07 15:09:15 +0000
committerVasil Zlatanov <v@skozl.com>2018-11-07 15:09:15 +0000
commitcb93af3155c318b9168fdec272d5203a57de3d47 (patch)
treed677e19beadba0e7315e3d0d2a54fc3f269ac8dd /train.py
parent8550540ad867a98b945934069aa4ce87f1ecf767 (diff)
downloadvz215_np1915-cb93af3155c318b9168fdec272d5203a57de3d47.tar.gz
vz215_np1915-cb93af3155c318b9168fdec272d5203a57de3d47.tar.bz2
vz215_np1915-cb93af3155c318b9168fdec272d5203a57de3d47.zip
Replace hardcoded constant
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train.py b/train.py
index 917c895..6ac963b 100755
--- a/train.py
+++ b/train.py
@@ -213,8 +213,8 @@ def main():
faces_train, faces_test, target_train, target_test = test_split(n_faces, raw_faces, args.split, args.seed)
if args.classifyalt:
- faces_train = faces_train.reshape(n_faces, 8, n_pixels)
- target_train = target_train.reshape(n_faces, 8)
+ faces_train = faces_train.reshape(n_faces, int(faces_train.shape[0]/n_faces), n_pixels)
+ target_train = target_train.reshape(n_faces, int(target_train.shape[0]/n_faces))
accuracy = np.zeros(n_faces)
distances = np.zeros((n_faces, faces_test.shape[0]))