diff options
author | nunzip <np.scarh@gmail.com> | 2019-03-08 11:34:59 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-03-08 11:34:59 +0000 |
commit | 54e248d0d28bb6faeaefa8b25195af236ec70150 (patch) | |
tree | 39690c5ce9f9b9b1a484cbc16663507e50e579a9 | |
parent | 035a11b98dc4b78ccab8ddc35a7fceaea9bb00c6 (diff) | |
download | e4-gan-54e248d0d28bb6faeaefa8b25195af236ec70150.tar.gz e4-gan-54e248d0d28bb6faeaefa8b25195af236ec70150.tar.bz2 e4-gan-54e248d0d28bb6faeaefa8b25195af236ec70150.zip |
Fix PCA Bug
-rw-r--r-- | lenet.py | 12 |
1 files changed, 8 insertions, 4 deletions
@@ -171,12 +171,16 @@ def test_classifier(model, x_test, y_true, conf_mat=False, pca=False): plt.show() if pca: set_pca = PCA(n_components=2) - pca_rep = set_pca.fit_transform(x_test) - pca_rep, y_tmp = shuffle(pca_rep, y_tmp, random_state=0) - plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true, edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('spectral', 10)) + pca_rep = np.reshape(x_test, (x_test.shape[0], x_test.shape[1]*x_test.shape[2])) + print(pca_rep.shape) + pca_rep = set_pca.fit_transform(pca_rep) + print(pca_rep.shape) + pca_rep, y_tmp = shuffle(pca_rep, y_true, random_state=0) + plt.scatter(pca_rep[:100, 0], pca_rep[:100, 1], c=y_true[:100], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Spectral', 10)) plt.xlabel('component 1') plt.ylabel('component 2') plt.colorbar(); + plt.show() return accuracy_score(y_true, y_pred), inception_score @@ -211,4 +215,4 @@ if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() print(y_t.shape) model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3) - print(test_classifier(model, x_t, y_t)) + print(test_classifier(model, x_t, y_t, pca=True)) |