From 7465d4fdde046843cb8bca3b233c2cdd99c39722 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Wed, 27 Feb 2019 18:23:44 +0000 Subject: Fix lenet funcs --- lenet.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) (limited to 'lenet.py') diff --git a/lenet.py b/lenet.py index e7756ae..8595eb8 100644 --- a/lenet.py +++ b/lenet.py @@ -8,7 +8,7 @@ from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D from tensorflow.keras import backend as K from tensorflow.keras import optimizers import matplotlib.pyplot as plt -import tensorflow.keras.metrics +from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random @@ -21,8 +21,9 @@ def import_mnist(): X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant') X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') - - return X_train, X_validation, X_test, y_train, y_validation, y_test + y_train = keras.utils.to_categorical(y_train, 10) + y_validation = keras.utils.to_categorical(y_validation, 10) + return X_train, y_train, X_validation, y_validation, X_test, y_test def plot_images(images, cls_true, cls_pred=None): assert len(images) == len(cls_true) == 9 @@ -44,8 +45,12 @@ def plot_images(images, cls_true, cls_pred=None): plt.show() def plot_example_errors(y_pred, y_true, X_test): + y_pred = np.argmax(y_pred, axis=1) + y_true = np.argmax(y_true, axis=1) correct_prediction = np.equal(y_pred, y_true) incorrect = np.equal(correct_prediction, False) + print(correct_prediction.shape) + print(incorrect[0]) images = X_test[incorrect] cls_pred = y_pred[incorrect] cls_true = y_true[incorrect] @@ -53,7 +58,7 @@ def plot_example_errors(y_pred, y_true, X_test): def get_lenet(shape): model = keras.Sequential() - model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape))) + model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape)) model.add(AveragePooling2D()) model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) @@ -101,20 +106,23 @@ def plot_history(history, metric = None): plt.ylabel('Loss') plt.xlabel('Epoch') -def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, EPOCHS=100, num_classes=10): - y_train = keras.utils.to_categorical(y_train, num_classes) - y_val = keras.utils.to_categorical(y_val, num_classes) +def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100): shape = (32, 32, 1) model = get_lenet(shape) sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', optimizer=sgd) - history = model.fit(x_train, y_train, batch_size=batch_size, epochs=EPOCHS, verbose=1, validation_data = (x_val, y_val)) + history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data = (x_val, y_val)) plot_history(history) return model def test_classifier(model, x_test, y_true): y_pred = model.predict(x_test) - print(metrics.categorical_accuracy(y_true, y_pred)) - plot_example_errors(y_pred, y_true, x_test) + print(categorical_accuracy(y_true, y_pred)) + plot_example_errors(y_pred, y_true, x_test) + +# If file run directly, perform quick test +x_train, y_train, x_val, y_val, _, _ = import_mnist() +model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1) +test_classifier(model, x_val, y_val) -- cgit v1.2.3-54-g00ecf