aboutsummaryrefslogtreecommitdiff
path: root/lenet.py
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 18:42:46 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 18:42:46 +0000
commite87abd6950701afc7301b15daa86ba60bc083a44 (patch)
tree19dabd10f1871f20a7b082163532e3b97ad24b8b /lenet.py
parentc5f5d81dc0233c03f339fdf932ef3f72871db3cf (diff)
parent53f94e754faabe129075b1c288c3c109376c34e8 (diff)
downloade4-gan-e87abd6950701afc7301b15daa86ba60bc083a44.tar.gz
e4-gan-e87abd6950701afc7301b15daa86ba60bc083a44.tar.bz2
e4-gan-e87abd6950701afc7301b15daa86ba60bc083a44.zip
Merge branch 'master' of skozl.com:e4-gan
Diffstat (limited to 'lenet.py')
-rw-r--r--lenet.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/lenet.py b/lenet.py
index d09f3cc..37123bb 100644
--- a/lenet.py
+++ b/lenet.py
@@ -125,6 +125,7 @@ def test_classifier(model, x_test, y_true):
plot_example_errors(y_pred, y_true, x_test)
# If file run directly, perform quick test
-x_train, y_train, x_val, y_val, _, _ = import_mnist()
-model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
-test_classifier(model, x_val, y_val)
+if __name__ == '__main__':
+ x_train, y_train, x_val, y_val, _, _ = import_mnist()
+ model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
+ test_classifier(model, x_val, y_val)