From 54e248d0d28bb6faeaefa8b25195af236ec70150 Mon Sep 17 00:00:00 2001 From: nunzip Date: Fri, 8 Mar 2019 11:34:59 +0000 Subject: Fix PCA Bug --- lenet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lenet.py b/lenet.py index 2ec7196..7c4c48c 100644 --- a/lenet.py +++ b/lenet.py @@ -171,12 +171,16 @@ def test_classifier(model, x_test, y_true, conf_mat=False, pca=False): plt.show() if pca: set_pca = PCA(n_components=2) - pca_rep = set_pca.fit_transform(x_test) - pca_rep, y_tmp = shuffle(pca_rep, y_tmp, random_state=0) - plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true, edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('spectral', 10)) + pca_rep = np.reshape(x_test, (x_test.shape[0], x_test.shape[1]*x_test.shape[2])) + print(pca_rep.shape) + pca_rep = set_pca.fit_transform(pca_rep) + print(pca_rep.shape) + pca_rep, y_tmp = shuffle(pca_rep, y_true, random_state=0) + plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true[:100], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Spectral', 10)) plt.xlabel('component 1') plt.ylabel('component 2') plt.colorbar(); + plt.show() return accuracy_score(y_true, y_pred), inception_score @@ -211,4 +215,4 @@ if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() print(y_t.shape) model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3) - print(test_classifier(model, x_t, y_t)) + print(test_classifier(model, x_t, y_t, pca=True)) -- cgit v1.2.3-70-g09d2