aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 20:25:46 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 20:25:46 +0000
commit562e80324a6363b57b191d22ae21517b9d80115b (patch)
treedc13aa3daf65cc09b69360f7e0fa6cf1803723d3
parenta7cb76b4131a7b5b142ac26aa2d47f7e8097c0db (diff)
parente7ac5212b90ac9058070c2d8f3e673cbc193ba08 (diff)
downloade4-gan-562e80324a6363b57b191d22ae21517b9d80115b.tar.gz
e4-gan-562e80324a6363b57b191d22ae21517b9d80115b.tar.bz2
e4-gan-562e80324a6363b57b191d22ae21517b9d80115b.zip
Merge branch 'master' of skozl.com:e4-gan
-rw-r--r--lenet.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/lenet.py b/lenet.py
index 4752fc3..c9ef78c 100644
--- a/lenet.py
+++ b/lenet.py
@@ -20,7 +20,7 @@ def import_mnist():
X_test, y_test = mnist.test.images, mnist.test.labels
y_train = keras.utils.to_categorical(y_train, 10)
y_validation = keras.utils.to_categorical(y_validation, 10)
- y_test = keras.utils.to_categorical(y_test, 10)
+ y_test = keras.utils.to_categorical(y_test, 10)
return X_train, y_train, X_validation, y_validation, X_test, y_test
def plot_images(images, cls_true, cls_pred=None):
@@ -128,6 +128,6 @@ def test_classifier(model, x_test, y_true):
# If file run directly, perform quick test
if __name__ == '__main__':
- x_train, y_train, x_val, y_val, _, _ = import_mnist()
+ x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
- test_classifier(model, x_val, y_val)
+ test_classifier(model, x_t, y_t)