aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lenet.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/lenet.py b/lenet.py
index d09f3cc..37123bb 100644
--- a/lenet.py
+++ b/lenet.py
@@ -125,6 +125,7 @@ def test_classifier(model, x_test, y_true):
plot_example_errors(y_pred, y_true, x_test)
# If file run directly, perform quick test
-x_train, y_train, x_val, y_val, _, _ = import_mnist()
-model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
-test_classifier(model, x_val, y_val)
+if __name__ == '__main__':
+ x_train, y_train, x_val, y_val, _, _ = import_mnist()
+ model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
+ test_classifier(model, x_val, y_val)