aboutsummaryrefslogtreecommitdiff
path: root/lenet.py
diff options
context:
space:
mode:
Diffstat (limited to 'lenet.py')
-rw-r--r--lenet.py58
1 files changed, 49 insertions, 9 deletions
diff --git a/lenet.py b/lenet.py
index 4950fe9..a94259b 100644
--- a/lenet.py
+++ b/lenet.py
@@ -11,9 +11,13 @@ from tensorflow.keras.metrics import categorical_accuracy
import numpy as np
import random
from sklearn.metrics import accuracy_score
+from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
-
+from sklearn.decomposition import PCA
from classifier_metrics_impl import classifier_score_from_logits
+from sklearn.utils import shuffle
+from sklearn.manifold import TSNE
+import scikitplot as skplt
def import_mnist():
from tensorflow.examples.tutorials.mnist import input_data
@@ -64,8 +68,7 @@ def get_lenet(shape):
model.add(Dense(units=120, activation='relu'))
model.add(Dense(units=84, activation='relu'))
- #model.add(Dense(units=10, activation = 'softmax'))
- model.add(Dense(units=10, activation = 'relu'))
+ model.add(Dense(units=10, activation = 'softmax'))
return model
def get_lenet_icp(shape):
@@ -140,14 +143,51 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
model.save_weights('./weights.h5')
return model
-def test_classifier(model, x_test, y_true):
+def plot_probas(model, x_test, y_true):
+ y_true = np.argmax(y_true, axis=1)
+ x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ probas = model.predict(x_test)
+ skplt.metrics.plot_roc(y_true, probas)
+ plt.show()
+ skplt.metrics.plot_precision_recall_curve(y_true, probas)
+ plt.show()
+
+def test_classifier(model, x_test, y_true, conf_mat=False, pca=False, tsne=False):
x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- y_pred = model.predict(x_test)
- logits = tf.convert_to_tensor(y_pred, dtype=tf.float32)
- inception_score = tf.keras.backend.eval(classifier_score_from_logits(logits))
- y_pred = np.argmax(y_pred, axis=1)
+ logits = model.predict(x_test)
+ tf_logits = tf.convert_to_tensor(logits, dtype=tf.float32)
+ inception_score = tf.keras.backend.eval(classifier_score_from_logits(tf_logits))
+ y_pred = np.argmax(logits, axis=1)
y_true = np.argmax(y_true, axis=1)
plot_example_errors(y_pred, y_true, x_test)
+ cm = confusion_matrix(y_true, y_pred)
+ if conf_mat:
+ plt.matshow(cm, cmap='Blues')
+ plt.colorbar()
+ plt.ylabel('Actual')
+ plt.xlabel('Predicted')
+ plt.show()
+ if pca:
+ set_pca = PCA(n_components=2)
+ pca_rep = set_pca.fit_transform(logits)
+ pca_rep, y_tmp = shuffle(pca_rep, y_true, random_state=0)
+ plt.scatter(pca_rep[:5000, 0], pca_rep[:5000, 1], c=y_tmp[:5000], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Paired', 10))
+ plt.xlabel('Feature 1')
+ plt.ylabel('Feature 2')
+ plt.colorbar();
+ plt.show()
+ if tsne:
+ tsne = TSNE(n_components=2, random_state=0)
+ components = tsne.fit_transform(logits)
+ print(components.shape)
+ components, y_tmp = shuffle(components, y_true, random_state=0)
+ plt.scatter(components[:5000, 0], components[:5000, 1], c=y_tmp[:5000], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Paired', 10))
+ plt.xlabel('Feature 1')
+ plt.ylabel('Feature 2')
+ plt.colorbar();
+ plt.show()
+
+
return accuracy_score(y_true, y_pred), inception_score
def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0):
@@ -181,4 +221,4 @@ if __name__ == '__main__':
x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
print(y_t.shape)
model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3)
- print(test_classifier(model, x_t, y_t))
+ print(test_classifier(model, x_t, y_t, pca=False, tsne=True))