aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 17:52:30 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 17:52:30 +0000
commitab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51 (patch)
tree0ab86728ae9ebcfff48ab01dcc0bb0f6f089353a
parent7d053270c97f2030500cee90f8c1a0b8cf1d5f64 (diff)
downloade4-gan-ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51.tar.gz
e4-gan-ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51.tar.bz2
e4-gan-ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51.zip
Set un separate training and testing functions
-rw-r--r--lenet.py142
1 files changed, 95 insertions, 47 deletions
diff --git a/lenet.py b/lenet.py
index f85bc6d..e7756ae 100644
--- a/lenet.py
+++ b/lenet.py
@@ -1,72 +1,120 @@
from __future__ import print_function
import tensorflow.keras as keras
+import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers
+import matplotlib.pyplot as plt
+import tensorflow.keras.metrics
+import numpy as np
+import random
-batch_size = 128
-num_classes = 10
+def import_mnist():
+ from tensorflow.examples.tutorials.mnist import input_data
+ mnist = input_data.read_data_sets("MNIST_data/", reshape=False)
+ X_train, y_train = mnist.train.images, mnist.train.labels
+ X_validation, y_validation = mnist.validation.images, mnist.validation.labels
+ X_test, y_test = mnist.test.images, mnist.test.labels
+ X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+
+ return X_train, X_validation, X_test, y_train, y_validation, y_test
+
+def plot_images(images, cls_true, cls_pred=None):
+ assert len(images) == len(cls_true) == 9
+ img_shape = (32, 32)
+ # Create figure with 3x3 sub-plots.
+ fig, axes = plt.subplots(3, 3)
+ fig.subplots_adjust(hspace=0.3, wspace=0.3)
+ for i, ax in enumerate(axes.flat):
+ # Plot image.
+ ax.imshow(images[i].reshape(img_shape), cmap='binary')
+ # Show true and predicted classes.
+ if cls_pred is None:
+ xlabel = "True: {0}".format(cls_true[i])
+ else:
+ xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])
+ ax.set_xlabel(xlabel)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ plt.show()
-def get_lenet():
+def plot_example_errors(y_pred, y_true, X_test):
+ correct_prediction = np.equal(y_pred, y_true)
+ incorrect = np.equal(correct_prediction, False)
+ images = X_test[incorrect]
+ cls_pred = y_pred[incorrect]
+ cls_true = y_true[incorrect]
+ plot_images(images=images[0:9], cls_true=cls_true[0:9], cls_pred=cls_pred[0:9].astype(np.int))
+
+def get_lenet(shape):
model = keras.Sequential()
-
- model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(28,28,1)))
+ model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape)))
model.add(AveragePooling2D())
model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))
model.add(AveragePooling2D())
-
model.add(Flatten())
model.add(Dense(units=120, activation='relu'))
-
model.add(Dense(units=84, activation='relu'))
-
model.add(Dense(units=10, activation = 'softmax'))
return model
+def plot_history(history, metric = None):
+ # Plots the loss history of training and validation (if existing)
+ # and a given metric
+
+ if metric != None:
+ fig, axes = plt.subplots(2,1)
+ axes[0].plot(history.history[metric])
+ try:
+ axes[0].plot(history.history['val_'+metric])
+ axes[0].legend(['Train', 'Val'])
+ except:
+ pass
+ axes[0].set_title('{:s}'.format(metric))
+ axes[0].set_ylabel('{:s}'.format(metric))
+ axes[0].set_xlabel('Epoch')
+ fig.subplots_adjust(hspace=0.5)
+ axes[1].plot(history.history['loss'])
+ try:
+ axes[1].plot(history.history['val_loss'])
+ axes[1].legend(['Train', 'Val'])
+ except:
+ pass
+ axes[1].set_title('Model Loss')
+ axes[1].set_ylabel('Loss')
+ axes[1].set_xlabel('Epoch')
+ else:
+ plt.plot(history.history['loss'])
+ try:
+ plt.plot(history.history['val_loss'])
+ plt.legend(['Train', 'Val'])
+ except:
+ pass
+ plt.title('Model Loss')
+ plt.ylabel('Loss')
+ plt.xlabel('Epoch')
-# input image dimensions
-img_rows, img_cols = 28, 28
-
-# the data, split between train and test sets
-(x_train, y_train), (x_test, y_test) = mnist.load_data()
-
-if K.image_data_format() == 'channels_first':
- x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
- x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
- input_shape = (1, img_rows, img_cols)
-else:
- x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
- x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
- input_shape = (img_rows, img_cols, 1)
-
-x_train = x_train.astype('float32')
-x_test = x_test.astype('float32')
-x_train /= 255
-x_test /= 255
-print('x_train shape:', x_train.shape)
-print(x_train.shape[0], 'train samples')
-print(x_test.shape[0], 'test samples')
-
-# convert class vectors to binary class matrices
-y_train = keras.utils.to_categorical(y_train, num_classes)
-y_test = keras.utils.to_categorical(y_test, num_classes)
-
-model = get_lenet()
-
-sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
-model.compile(loss='mean_squared_error', optimizer=sgd)
-
-model.fit(x_train, y_train,
- batch_size=batch_size,
- epochs=1,
- verbose=1)
+def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, EPOCHS=100, num_classes=10):
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_val = keras.utils.to_categorical(y_val, num_classes)
+ shape = (32, 32, 1)
+ model = get_lenet(shape)
-y_pred = model.predict(x_test)
+ sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
+ model.compile(loss='categorical_crossentropy', optimizer=sgd)
+
+ history = model.fit(x_train, y_train, batch_size=batch_size, epochs=EPOCHS, verbose=1, validation_data = (x_val, y_val))
+ plot_history(history)
+ return model
-print(y_pred.shape)
-print(y_test.shape)
+def test_classifier(model, x_test, y_true):
+ y_pred = model.predict(x_test)
+ print(metrics.categorical_accuracy(y_true, y_pred))
+ plot_example_errors(y_pred, y_true, x_test)