diff options
author | nunzip <np.scarh@gmail.com> | 2019-02-27 17:52:30 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-02-27 17:52:30 +0000 |
commit | ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51 (patch) | |
tree | 0ab86728ae9ebcfff48ab01dcc0bb0f6f089353a /lenet.py | |
parent | 7d053270c97f2030500cee90f8c1a0b8cf1d5f64 (diff) | |
download | e4-gan-ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51.tar.gz e4-gan-ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51.tar.bz2 e4-gan-ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51.zip |
Set un separate training and testing functions
Diffstat (limited to 'lenet.py')
-rw-r--r-- | lenet.py | 142 |
1 files changed, 95 insertions, 47 deletions
@@ -1,72 +1,120 @@ from __future__ import print_function import tensorflow.keras as keras +import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D from tensorflow.keras import backend as K from tensorflow.keras import optimizers +import matplotlib.pyplot as plt +import tensorflow.keras.metrics +import numpy as np +import random -batch_size = 128 -num_classes = 10 +def import_mnist(): + from tensorflow.examples.tutorials.mnist import input_data + mnist = input_data.read_data_sets("MNIST_data/", reshape=False) + X_train, y_train = mnist.train.images, mnist.train.labels + X_validation, y_validation = mnist.validation.images, mnist.validation.labels + X_test, y_test = mnist.test.images, mnist.test.labels + X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') + X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant') + X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') + + return X_train, X_validation, X_test, y_train, y_validation, y_test + +def plot_images(images, cls_true, cls_pred=None): + assert len(images) == len(cls_true) == 9 + img_shape = (32, 32) + # Create figure with 3x3 sub-plots. + fig, axes = plt.subplots(3, 3) + fig.subplots_adjust(hspace=0.3, wspace=0.3) + for i, ax in enumerate(axes.flat): + # Plot image. + ax.imshow(images[i].reshape(img_shape), cmap='binary') + # Show true and predicted classes. + if cls_pred is None: + xlabel = "True: {0}".format(cls_true[i]) + else: + xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i]) + ax.set_xlabel(xlabel) + ax.set_xticks([]) + ax.set_yticks([]) + plt.show() -def get_lenet(): +def plot_example_errors(y_pred, y_true, X_test): + correct_prediction = np.equal(y_pred, y_true) + incorrect = np.equal(correct_prediction, False) + images = X_test[incorrect] + cls_pred = y_pred[incorrect] + cls_true = y_true[incorrect] + plot_images(images=images[0:9], cls_true=cls_true[0:9], cls_pred=cls_pred[0:9].astype(np.int)) + +def get_lenet(shape): model = keras.Sequential() - - model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(28,28,1))) + model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape))) model.add(AveragePooling2D()) model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) model.add(AveragePooling2D()) - model.add(Flatten()) model.add(Dense(units=120, activation='relu')) - model.add(Dense(units=84, activation='relu')) - model.add(Dense(units=10, activation = 'softmax')) return model +def plot_history(history, metric = None): + # Plots the loss history of training and validation (if existing) + # and a given metric + + if metric != None: + fig, axes = plt.subplots(2,1) + axes[0].plot(history.history[metric]) + try: + axes[0].plot(history.history['val_'+metric]) + axes[0].legend(['Train', 'Val']) + except: + pass + axes[0].set_title('{:s}'.format(metric)) + axes[0].set_ylabel('{:s}'.format(metric)) + axes[0].set_xlabel('Epoch') + fig.subplots_adjust(hspace=0.5) + axes[1].plot(history.history['loss']) + try: + axes[1].plot(history.history['val_loss']) + axes[1].legend(['Train', 'Val']) + except: + pass + axes[1].set_title('Model Loss') + axes[1].set_ylabel('Loss') + axes[1].set_xlabel('Epoch') + else: + plt.plot(history.history['loss']) + try: + plt.plot(history.history['val_loss']) + plt.legend(['Train', 'Val']) + except: + pass + plt.title('Model Loss') + plt.ylabel('Loss') + plt.xlabel('Epoch') -# input image dimensions -img_rows, img_cols = 28, 28 - -# the data, split between train and test sets -(x_train, y_train), (x_test, y_test) = mnist.load_data() - -if K.image_data_format() == 'channels_first': - x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) - x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) - input_shape = (1, img_rows, img_cols) -else: - x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) - x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) - input_shape = (img_rows, img_cols, 1) - -x_train = x_train.astype('float32') -x_test = x_test.astype('float32') -x_train /= 255 -x_test /= 255 -print('x_train shape:', x_train.shape) -print(x_train.shape[0], 'train samples') -print(x_test.shape[0], 'test samples') - -# convert class vectors to binary class matrices -y_train = keras.utils.to_categorical(y_train, num_classes) -y_test = keras.utils.to_categorical(y_test, num_classes) - -model = get_lenet() - -sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='mean_squared_error', optimizer=sgd) - -model.fit(x_train, y_train, - batch_size=batch_size, - epochs=1, - verbose=1) +def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, EPOCHS=100, num_classes=10): + y_train = keras.utils.to_categorical(y_train, num_classes) + y_val = keras.utils.to_categorical(y_val, num_classes) + shape = (32, 32, 1) + model = get_lenet(shape) -y_pred = model.predict(x_test) + sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True) + model.compile(loss='categorical_crossentropy', optimizer=sgd) + + history = model.fit(x_train, y_train, batch_size=batch_size, epochs=EPOCHS, verbose=1, validation_data = (x_val, y_val)) + plot_history(history) + return model -print(y_pred.shape) -print(y_test.shape) +def test_classifier(model, x_test, y_true): + y_pred = model.predict(x_test) + print(metrics.categorical_accuracy(y_true, y_pred)) + plot_example_errors(y_pred, y_true, x_test) |