aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lenet.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/lenet.py b/lenet.py
index 3440af7..2ec7196 100644
--- a/lenet.py
+++ b/lenet.py
@@ -13,8 +13,9 @@ import random
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
-
+from sklearn.decomposition import PCA
from classifier_metrics_impl import classifier_score_from_logits
+from sklearn.utils import shuffle
def import_mnist():
from tensorflow.examples.tutorials.mnist import input_data
@@ -82,6 +83,19 @@ def get_lenet_icp(shape):
model.add(Dense(units=10, activation = 'relu'))
return model
+def get_lenet_pen(shape):
+ model = keras.Sequential()
+ model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(32,32,1)))
+ model.add(AveragePooling2D())
+
+ model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))
+ model.add(AveragePooling2D())
+ model.add(Flatten())
+
+ model.add(Dense(units=120, activation='relu'))
+ model.add(Dense(units=84, activation='relu'))
+ return model
+
def plot_history(history, metric = None):
# Plots the loss history of training and validation (if existing)
# and a given metric
@@ -140,7 +154,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
model.save_weights('./weights.h5')
return model
-def test_classifier(model, x_test, y_true, conf_mat=False):
+def test_classifier(model, x_test, y_true, conf_mat=False, pca=False):
x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
logits = tf.convert_to_tensor(y_pred, dtype=tf.float32)
@@ -155,6 +169,15 @@ def test_classifier(model, x_test, y_true, conf_mat=False):
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show()
+ if pca:
+ set_pca = PCA(n_components=2)
+ pca_rep = set_pca.fit_transform(x_test)
+ pca_rep, y_tmp = shuffle(pca_rep, y_tmp, random_state=0)
+ plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true, edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('spectral', 10))
+ plt.xlabel('component 1')
+ plt.ylabel('component 2')
+ plt.colorbar();
+
return accuracy_score(y_true, y_pred), inception_score
def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0):