aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-08 11:34:59 +0000
committernunzip <np.scarh@gmail.com>2019-03-08 11:34:59 +0000
commit54e248d0d28bb6faeaefa8b25195af236ec70150 (patch)
tree39690c5ce9f9b9b1a484cbc16663507e50e579a9
parent035a11b98dc4b78ccab8ddc35a7fceaea9bb00c6 (diff)
downloade4-gan-54e248d0d28bb6faeaefa8b25195af236ec70150.tar.gz
e4-gan-54e248d0d28bb6faeaefa8b25195af236ec70150.tar.bz2
e4-gan-54e248d0d28bb6faeaefa8b25195af236ec70150.zip
Fix PCA Bug
-rw-r--r--lenet.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/lenet.py b/lenet.py
index 2ec7196..7c4c48c 100644
--- a/lenet.py
+++ b/lenet.py
@@ -171,12 +171,16 @@ def test_classifier(model, x_test, y_true, conf_mat=False, pca=False):
plt.show()
if pca:
set_pca = PCA(n_components=2)
- pca_rep = set_pca.fit_transform(x_test)
- pca_rep, y_tmp = shuffle(pca_rep, y_tmp, random_state=0)
- plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true, edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('spectral', 10))
+ pca_rep = np.reshape(x_test, (x_test.shape[0], x_test.shape[1]*x_test.shape[2]))
+ print(pca_rep.shape)
+ pca_rep = set_pca.fit_transform(pca_rep)
+ print(pca_rep.shape)
+ pca_rep, y_tmp = shuffle(pca_rep, y_true, random_state=0)
+ plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true[:100], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Spectral', 10))
plt.xlabel('component 1')
plt.ylabel('component 2')
plt.colorbar();
+ plt.show()
return accuracy_score(y_true, y_pred), inception_score
@@ -211,4 +215,4 @@ if __name__ == '__main__':
x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
print(y_t.shape)
model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3)
- print(test_classifier(model, x_t, y_t))
+ print(test_classifier(model, x_t, y_t, pca=True))