diff options
Diffstat (limited to 'lenet.py')
-rw-r--r-- | lenet.py | 15 |
1 files changed, 8 insertions, 7 deletions
@@ -18,9 +18,6 @@ def import_mnist(): X_train, y_train = mnist.train.images, mnist.train.labels X_validation, y_validation = mnist.validation.images, mnist.validation.labels X_test, y_test = mnist.test.images, mnist.test.labels - X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') - X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant') - X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_train = keras.utils.to_categorical(y_train, 10) y_validation = keras.utils.to_categorical(y_validation, 10) return X_train, y_train, X_validation, y_validation, X_test, y_test @@ -49,8 +46,6 @@ def plot_example_errors(y_pred, y_true, X_test): y_true = np.argmax(y_true, axis=1) correct_prediction = np.equal(y_pred, y_true) incorrect = np.equal(correct_prediction, False) - print(correct_prediction.shape) - print(incorrect[0]) images = X_test[incorrect] cls_pred = y_pred[incorrect] cls_true = y_true[incorrect] @@ -106,11 +101,16 @@ def plot_history(history, metric = None): plt.ylabel('Loss') plt.xlabel('Epoch') -def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=['categorical_accuracy'], optimizer = None): +def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=[categorical_accuracy], optimizer = None): shape = (32, 32, 1) + + # Pad data to 32x32 (MNIST is 28x28) + x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') + x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant') + model = get_lenet(shape) - if optimizer = None: + if optimizer == None: optimizer = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', metrics=metrics, optimizer=optimizer) @@ -120,6 +120,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, return model def test_classifier(model, x_test, y_true): + x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_pred = model.predict(x_test) print(categorical_accuracy(y_true, y_pred)) plot_example_errors(y_pred, y_true, x_test) |