aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 22:44:05 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 22:44:05 +0000
commit508f3feb670506b63b30b6032949913efd8c8ca1 (patch)
tree9a46e18c8c4e8330df19f3a06f6b75a5c9b40eea
parent35216b30bdff05c04fd4846cfc8433b97218139f (diff)
parent3f7889e7823017acb9dcfa53c6c60f899a37a05c (diff)
downloade4-gan-508f3feb670506b63b30b6032949913efd8c8ca1.tar.gz
e4-gan-508f3feb670506b63b30b6032949913efd8c8ca1.tar.bz2
e4-gan-508f3feb670506b63b30b6032949913efd8c8ca1.zip
Merge branch 'master' of skozl.com:e4-gan
-rw-r--r--lenet.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/lenet.py b/lenet.py
index c9ef78c..57ce218 100644
--- a/lenet.py
+++ b/lenet.py
@@ -1,7 +1,6 @@
from __future__ import print_function
import tensorflow.keras as keras
import tensorflow as tf
-from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
@@ -11,6 +10,7 @@ import matplotlib.pyplot as plt
from tensorflow.keras.metrics import categorical_accuracy
import numpy as np
import random
+from sklearn.metrics import accuracy_score
def import_mnist():
from tensorflow.examples.tutorials.mnist import input_data
@@ -43,8 +43,6 @@ def plot_images(images, cls_true, cls_pred=None):
plt.show()
def plot_example_errors(y_pred, y_true, X_test):
- y_pred = np.argmax(y_pred, axis=1)
- y_true = np.argmax(y_true, axis=1)
correct_prediction = np.equal(y_pred, y_true)
incorrect = np.equal(correct_prediction, False)
images = X_test[incorrect]
@@ -123,11 +121,14 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
def test_classifier(model, x_test, y_true):
x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
- print(categorical_accuracy(y_true, y_pred))
+ y_pred = np.argmax(y_pred, axis=1)
+ y_true = np.argmax(y_true, axis=1)
+ print("Test acc:", accuracy_score(y_true, y_pred))
plot_example_errors(y_pred, y_true, x_test)
# If file run directly, perform quick test
if __name__ == '__main__':
x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
+ print(y_t.shape)
model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1)
test_classifier(model, x_t, y_t)