aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-27 18:35:41 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-27 18:35:41 +0000
commit53f94e754faabe129075b1c288c3c109376c34e8 (patch)
tree9b310f86d22700f6254ba258409fc0815df3e70b
parent9c26f3b6e6b317c910bf3bdafc9b070c151dff4a (diff)
downloade4-gan-53f94e754faabe129075b1c288c3c109376c34e8.tar.gz
e4-gan-53f94e754faabe129075b1c288c3c109376c34e8.tar.bz2
e4-gan-53f94e754faabe129075b1c288c3c109376c34e8.zip
Add if __main__
-rw-r--r--lenet.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/lenet.py b/lenet.py
index d09f3cc..37123bb 100644
--- a/lenet.py
+++ b/lenet.py
@@ -125,6 +125,7 @@ def test_classifier(model, x_test, y_true):
plot_example_errors(y_pred, y_true, x_test)
# If file run directly, perform quick test
-x_train, y_train, x_val, y_val, _, _ = import_mnist()
-model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
-test_classifier(model, x_val, y_val)
+if __name__ == '__main__':
+ x_train, y_train, x_val, y_val, _, _ = import_mnist()
+ model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
+ test_classifier(model, x_val, y_val)