From e58605e30e90bbfcbfd37dcac57e9d97d4c17a85 Mon Sep 17 00:00:00 2001 From: nunzip Date: Fri, 8 Mar 2019 02:03:51 +0000 Subject: Add confusion matrix --- lenet.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lenet.py b/lenet.py index 6bea8dd..3440af7 100644 --- a/lenet.py +++ b/lenet.py @@ -11,6 +11,7 @@ from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random from sklearn.metrics import accuracy_score +from sklearn.metrics import confusion_matrix from sklearn.model_selection import train_test_split from classifier_metrics_impl import classifier_score_from_logits @@ -139,7 +140,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, model.save_weights('./weights.h5') return model -def test_classifier(model, x_test, y_true): +def test_classifier(model, x_test, y_true, conf_mat=False): x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_pred = model.predict(x_test) logits = tf.convert_to_tensor(y_pred, dtype=tf.float32) @@ -147,6 +148,13 @@ def test_classifier(model, x_test, y_true): y_pred = np.argmax(y_pred, axis=1) y_true = np.argmax(y_true, axis=1) plot_example_errors(y_pred, y_true, x_test) + cm = confusion_matrix(y_true, y_pred) + if conf_mat: + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() return accuracy_score(y_true, y_pred), inception_score def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0): -- cgit v1.2.3-54-g00ecf