diff options
author | Vasil Zlatanov <v@skozl.com> | 2019-03-13 20:03:15 +0000 |
---|---|---|
committer | Vasil Zlatanov <v@skozl.com> | 2019-03-13 20:03:15 +0000 |
commit | fb6259d3285b6c3aa22069fffdb756a0342901b5 (patch) | |
tree | 599454bef502e145b0f4bb49a177e2baa22ff0bb /lenet.py | |
parent | 03f2c41ac69084cde7a61eb04303078e3c4785a7 (diff) | |
parent | 9945d9fe431f0b01c528b311acb685bebd99ab48 (diff) | |
download | e4-gan-fb6259d3285b6c3aa22069fffdb756a0342901b5.tar.gz e4-gan-fb6259d3285b6c3aa22069fffdb756a0342901b5.tar.bz2 e4-gan-fb6259d3285b6c3aa22069fffdb756a0342901b5.zip |
Merge branch 'master' of skozl.com:e4-gan
Diffstat (limited to 'lenet.py')
-rw-r--r-- | lenet.py | 58 |
1 files changed, 49 insertions, 9 deletions
@@ -11,9 +11,13 @@ from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random from sklearn.metrics import accuracy_score +from sklearn.metrics import confusion_matrix from sklearn.model_selection import train_test_split - +from sklearn.decomposition import PCA from classifier_metrics_impl import classifier_score_from_logits +from sklearn.utils import shuffle +from sklearn.manifold import TSNE +import scikitplot as skplt def import_mnist(): from tensorflow.examples.tutorials.mnist import input_data @@ -64,8 +68,7 @@ def get_lenet(shape): model.add(Dense(units=120, activation='relu')) model.add(Dense(units=84, activation='relu')) - #model.add(Dense(units=10, activation = 'softmax')) - model.add(Dense(units=10, activation = 'relu')) + model.add(Dense(units=10, activation = 'softmax')) return model def get_lenet_icp(shape): @@ -140,14 +143,51 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, model.save_weights('./weights.h5') return model -def test_classifier(model, x_test, y_true): +def plot_probas(model, x_test, y_true): + y_true = np.argmax(y_true, axis=1) + x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') + probas = model.predict(x_test) + skplt.metrics.plot_roc(y_true, probas) + plt.show() + skplt.metrics.plot_precision_recall_curve(y_true, probas) + plt.show() + +def test_classifier(model, x_test, y_true, conf_mat=False, pca=False, tsne=False): x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') - y_pred = model.predict(x_test) - logits = tf.convert_to_tensor(y_pred, dtype=tf.float32) - inception_score = tf.keras.backend.eval(classifier_score_from_logits(logits)) - y_pred = np.argmax(y_pred, axis=1) + logits = model.predict(x_test) + tf_logits = tf.convert_to_tensor(logits, dtype=tf.float32) + inception_score = tf.keras.backend.eval(classifier_score_from_logits(tf_logits)) + y_pred = np.argmax(logits, axis=1) y_true = np.argmax(y_true, axis=1) plot_example_errors(y_pred, y_true, x_test) + cm = confusion_matrix(y_true, y_pred) + if conf_mat: + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() + if pca: + set_pca = PCA(n_components=2) + pca_rep = set_pca.fit_transform(logits) + pca_rep, y_tmp = shuffle(pca_rep, y_true, random_state=0) + plt.scatter(pca_rep[:5000, 0], pca_rep[:5000, 1], c=y_tmp[:5000], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Paired', 10)) + plt.xlabel('Feature 1') + plt.ylabel('Feature 2') + plt.colorbar(); + plt.show() + if tsne: + tsne = TSNE(n_components=2, random_state=0) + components = tsne.fit_transform(logits) + print(components.shape) + components, y_tmp = shuffle(components, y_true, random_state=0) + plt.scatter(components[:5000, 0], components[:5000, 1], c=y_tmp[:5000], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Paired', 10)) + plt.xlabel('Feature 1') + plt.ylabel('Feature 2') + plt.colorbar(); + plt.show() + + return accuracy_score(y_true, y_pred), inception_score def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0): @@ -181,4 +221,4 @@ if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() print(y_t.shape) model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3) - print(test_classifier(model, x_t, y_t)) + print(test_classifier(model, x_t, y_t, pca=False, tsne=True)) |