diff options
-rw-r--r-- | lenet.py | 10 |
1 files changed, 9 insertions, 1 deletions
@@ -11,6 +11,7 @@ from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random from sklearn.metrics import accuracy_score +from sklearn.metrics import confusion_matrix from sklearn.model_selection import train_test_split from classifier_metrics_impl import classifier_score_from_logits @@ -139,7 +140,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, model.save_weights('./weights.h5') return model -def test_classifier(model, x_test, y_true): +def test_classifier(model, x_test, y_true, conf_mat=False): x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_pred = model.predict(x_test) logits = tf.convert_to_tensor(y_pred, dtype=tf.float32) @@ -147,6 +148,13 @@ def test_classifier(model, x_test, y_true): y_pred = np.argmax(y_pred, axis=1) y_true = np.argmax(y_true, axis=1) plot_example_errors(y_pred, y_true, x_test) + cm = confusion_matrix(y_true, y_pred) + if conf_mat: + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() return accuracy_score(y_true, y_pred), inception_score def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0): |