-rw-r--r--  lenet.py  10
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/lenet.py b/lenet.py
index 6bea8dd..3440af7 100644
--- a/lenet.py
+++ b/lenet.py
@@ -11,6 +11,7 @@ from tensorflow.keras.metrics import categorical_accuracy
 import numpy as np
 import random
 from sklearn.metrics import accuracy_score
+from sklearn.metrics import confusion_matrix
 from sklearn.model_selection import train_test_split
 from classifier_metrics_impl import classifier_score_from_logits
@@ -139,7 +140,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
     model.save_weights('./weights.h5')
     return model
 
-def test_classifier(model, x_test, y_true):
+def test_classifier(model, x_test, y_true, conf_mat=False):
     x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
     y_pred = model.predict(x_test)
     logits = tf.convert_to_tensor(y_pred, dtype=tf.float32)
@@ -147,6 +148,13 @@ def test_classifier(model, x_test, y_true):
     y_pred = np.argmax(y_pred, axis=1)
     y_true = np.argmax(y_true, axis=1)
     plot_example_errors(y_pred, y_true, x_test)
+    cm = confusion_matrix(y_true, y_pred)
+    if conf_mat:
+        plt.matshow(cm, cmap='Blues')
+        plt.colorbar()
+        plt.ylabel('Actual')
+        plt.xlabel('Predicted')
+        plt.show()
     return accuracy_score(y_true, y_pred), inception_score
 
 def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0):
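
For reference, below is a minimal, self-contained sketch (not part of this commit) of what the new conf_mat branch does: it builds a confusion matrix with scikit-learn and renders it with matplotlib's matshow. The dummy y_true / y_pred arrays here are placeholders standing in for the argmax-decoded labels that test_classifier computes.

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import confusion_matrix

    # Dummy class indices standing in for argmax(y_true) and argmax(y_pred).
    y_true = np.array([0, 1, 2, 2, 1, 0, 2, 1])
    y_pred = np.array([0, 2, 2, 2, 1, 0, 1, 1])

    cm = confusion_matrix(y_true, y_pred)   # rows = actual class, columns = predicted class
    plt.matshow(cm, cmap='Blues')           # same rendering as the new conf_mat branch
    plt.colorbar()
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.show()

With a trained model in hand, the same plot would come from a call along the lines of test_classifier(model, x_test, y_test, conf_mat=True) (variable names hypothetical), which still returns the accuracy and inception score as before.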