aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-27 18:23:44 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-27 18:23:44 +0000
commit7465d4fdde046843cb8bca3b233c2cdd99c39722 (patch)
tree43beef701365cbb249cf4aac04b171665e09fd99
parentab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51 (diff)
downloade4-gan-7465d4fdde046843cb8bca3b233c2cdd99c39722.tar.gz
e4-gan-7465d4fdde046843cb8bca3b233c2cdd99c39722.tar.bz2
e4-gan-7465d4fdde046843cb8bca3b233c2cdd99c39722.zip
Fix lenet funcs
-rw-r--r--lenet.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/lenet.py b/lenet.py
index e7756ae..8595eb8 100644
--- a/lenet.py
+++ b/lenet.py
@@ -8,7 +8,7 @@ from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt
-import tensorflow.keras.metrics
+from tensorflow.keras.metrics import categorical_accuracy
import numpy as np
import random
@@ -21,8 +21,9 @@ def import_mnist():
X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
-
- return X_train, X_validation, X_test, y_train, y_validation, y_test
+ y_train = keras.utils.to_categorical(y_train, 10)
+ y_validation = keras.utils.to_categorical(y_validation, 10)
+ return X_train, y_train, X_validation, y_validation, X_test, y_test
def plot_images(images, cls_true, cls_pred=None):
assert len(images) == len(cls_true) == 9
@@ -44,8 +45,12 @@ def plot_images(images, cls_true, cls_pred=None):
plt.show()
def plot_example_errors(y_pred, y_true, X_test):
+ y_pred = np.argmax(y_pred, axis=1)
+ y_true = np.argmax(y_true, axis=1)
correct_prediction = np.equal(y_pred, y_true)
incorrect = np.equal(correct_prediction, False)
+ print(correct_prediction.shape)
+ print(incorrect[0])
images = X_test[incorrect]
cls_pred = y_pred[incorrect]
cls_true = y_true[incorrect]
@@ -53,7 +58,7 @@ def plot_example_errors(y_pred, y_true, X_test):
def get_lenet(shape):
model = keras.Sequential()
- model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape)))
+ model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape))
model.add(AveragePooling2D())
model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))
@@ -101,20 +106,23 @@ def plot_history(history, metric = None):
plt.ylabel('Loss')
plt.xlabel('Epoch')
-def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, EPOCHS=100, num_classes=10):
- y_train = keras.utils.to_categorical(y_train, num_classes)
- y_val = keras.utils.to_categorical(y_val, num_classes)
+def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100):
shape = (32, 32, 1)
model = get_lenet(shape)
sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)
- history = model.fit(x_train, y_train, batch_size=batch_size, epochs=EPOCHS, verbose=1, validation_data = (x_val, y_val))
+ history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data = (x_val, y_val))
plot_history(history)
return model
def test_classifier(model, x_test, y_true):
y_pred = model.predict(x_test)
- print(metrics.categorical_accuracy(y_true, y_pred))
- plot_example_errors(y_pred, y_true, x_test)
+ print(categorical_accuracy(y_true, y_pred))
+ plot_example_errors(y_pred, y_true, x_test)
+
+# If file run directly, perform quick test
+x_train, y_train, x_val, y_val, _, _ = import_mnist()
+model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
+test_classifier(model, x_val, y_val)