aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lenet.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/lenet.py b/lenet.py
index 37123bb..93a80cf 100644
--- a/lenet.py
+++ b/lenet.py
@@ -18,9 +18,6 @@ def import_mnist():
X_train, y_train = mnist.train.images, mnist.train.labels
X_validation, y_validation = mnist.validation.images, mnist.validation.labels
X_test, y_test = mnist.test.images, mnist.test.labels
- X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_train = keras.utils.to_categorical(y_train, 10)
y_validation = keras.utils.to_categorical(y_validation, 10)
return X_train, y_train, X_validation, y_validation, X_test, y_test
@@ -108,6 +105,11 @@ def plot_history(history, metric = None):
def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=['categorical_accuracy'], optimizer = None):
shape = (32, 32, 1)
+
+ # Pad data to 32x32 (MNIST is 28x28)
+ x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+
model = get_lenet(shape)
if optimizer = None:
@@ -120,6 +122,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
return model
def test_classifier(model, x_test, y_true):
+ x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
print(categorical_accuracy(y_true, y_pred))
plot_example_errors(y_pred, y_true, x_test)