aboutsummaryrefslogtreecommitdiff
path: root/lenet.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-27 20:39:18 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-27 20:39:18 +0000
commitdf27559be9df36d14fe0b9e6bc389fa337b2c491 (patch)
tree1c98d77b4bdc879d7954aeed340caf4e240f772a /lenet.py
parente7ac5212b90ac9058070c2d8f3e673cbc193ba08 (diff)
downloade4-gan-df27559be9df36d14fe0b9e6bc389fa337b2c491.tar.gz
e4-gan-df27559be9df36d14fe0b9e6bc389fa337b2c491.tar.bz2
e4-gan-df27559be9df36d14fe0b9e6bc389fa337b2c491.zip
Update accuracy test
Diffstat (limited to 'lenet.py')
-rw-r--r--lenet.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/lenet.py b/lenet.py
index c9ef78c..57ce218 100644
--- a/lenet.py
+++ b/lenet.py
@@ -1,7 +1,6 @@
from __future__ import print_function
import tensorflow.keras as keras
import tensorflow as tf
-from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
@@ -11,6 +10,7 @@ import matplotlib.pyplot as plt
from tensorflow.keras.metrics import categorical_accuracy
import numpy as np
import random
+from sklearn.metrics import accuracy_score
def import_mnist():
from tensorflow.examples.tutorials.mnist import input_data
@@ -43,8 +43,6 @@ def plot_images(images, cls_true, cls_pred=None):
plt.show()
def plot_example_errors(y_pred, y_true, X_test):
- y_pred = np.argmax(y_pred, axis=1)
- y_true = np.argmax(y_true, axis=1)
correct_prediction = np.equal(y_pred, y_true)
incorrect = np.equal(correct_prediction, False)
images = X_test[incorrect]
@@ -123,11 +121,14 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
def test_classifier(model, x_test, y_true):
x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
- print(categorical_accuracy(y_true, y_pred))
+ y_pred = np.argmax(y_pred, axis=1)
+ y_true = np.argmax(y_true, axis=1)
+ print("Test acc:", accuracy_score(y_true, y_pred))
plot_example_errors(y_pred, y_true, x_test)
# If file run directly, perform quick test
if __name__ == '__main__':
x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
+ print(y_t.shape)
model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
test_classifier(model, x_t, y_t)