aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lenet.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/lenet.py b/lenet.py
index 37123bb..982cd7e 100644
--- a/lenet.py
+++ b/lenet.py
@@ -18,9 +18,6 @@ def import_mnist():
X_train, y_train = mnist.train.images, mnist.train.labels
X_validation, y_validation = mnist.validation.images, mnist.validation.labels
X_test, y_test = mnist.test.images, mnist.test.labels
- X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_train = keras.utils.to_categorical(y_train, 10)
y_validation = keras.utils.to_categorical(y_validation, 10)
return X_train, y_train, X_validation, y_validation, X_test, y_test
@@ -49,8 +46,6 @@ def plot_example_errors(y_pred, y_true, X_test):
y_true = np.argmax(y_true, axis=1)
correct_prediction = np.equal(y_pred, y_true)
incorrect = np.equal(correct_prediction, False)
- print(correct_prediction.shape)
- print(incorrect[0])
images = X_test[incorrect]
cls_pred = y_pred[incorrect]
cls_true = y_true[incorrect]
@@ -106,11 +101,16 @@ def plot_history(history, metric = None):
plt.ylabel('Loss')
plt.xlabel('Epoch')
-def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=['categorical_accuracy'], optimizer = None):
+def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=[categorical_accuracy], optimizer = None):
shape = (32, 32, 1)
+
+ # Pad data to 32x32 (MNIST is 28x28)
+ x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+
model = get_lenet(shape)
- if optimizer = None:
+ if optimizer == None:
optimizer = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', metrics=metrics, optimizer=optimizer)
@@ -120,6 +120,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
return model
def test_classifier(model, x_test, y_true):
+ x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
print(categorical_accuracy(y_true, y_pred))
plot_example_errors(y_pred, y_true, x_test)