From 7eae715bbac0a3ace5075e79e1b85b2ceb45c278 Mon Sep 17 00:00:00 2001
From: Vasil Zlatanov <v@skozl.com>
Date: Wed, 27 Feb 2019 19:48:46 +0000
Subject: Also one_hot test

---
 lenet.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lenet.py b/lenet.py
index 982cd7e..c9ef78c 100644
--- a/lenet.py
+++ b/lenet.py
@@ -20,6 +20,7 @@ def import_mnist():
   X_test, y_test = mnist.test.images, mnist.test.labels
   y_train = keras.utils.to_categorical(y_train, 10)
   y_validation = keras.utils.to_categorical(y_validation, 10)  
+  y_test = keras.utils.to_categorical(y_test, 10)
   return X_train, y_train, X_validation, y_validation, X_test, y_test
 
 def plot_images(images, cls_true, cls_pred=None):
@@ -127,6 +128,6 @@ def test_classifier(model, x_test, y_true):
 
 # If file run directly, perform quick test
 if __name__ == '__main__':
-  x_train, y_train, x_val, y_val, _, _ = import_mnist()
+  x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
   model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
-  test_classifier(model, x_val, y_val)
+  test_classifier(model, x_t, y_t)
-- 
cgit v1.2.3-70-g09d2