diff options
author | Vasil Zlatanov <vasil@netcraft.com> | 2019-03-06 23:16:36 +0000 |
---|---|---|
committer | Vasil Zlatanov <vasil@netcraft.com> | 2019-03-06 23:16:36 +0000 |
commit | b418990448f461da50a732b4e66dd8e9066199d8 (patch) | |
tree | 1a9f4b63bd1e29f6b74cffe070d8b1ca421e9a17 /lenet.py | |
parent | 2ebd62018f4aec3d2e4c1ce14b7b85a5d46309e9 (diff) | |
download | e4-gan-b418990448f461da50a732b4e66dd8e9066199d8.tar.gz e4-gan-b418990448f461da50a732b4e66dd8e9066199d8.tar.bz2 e4-gan-b418990448f461da50a732b4e66dd8e9066199d8.zip |
Return inception score as well
Diffstat (limited to 'lenet.py')
-rw-r--r-- | lenet.py | 8 |
1 files changed, 6 insertions, 2 deletions
@@ -13,6 +13,8 @@ import random from sklearn.metrics import accuracy_score from sklearn.model_selection import train_test_split +from classifier_metrics_impl import classifier_score_from_logits + def import_mnist(): from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", reshape=False) @@ -126,10 +128,12 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, def test_classifier(model, x_test, y_true): x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_pred = model.predict(x_test) + logits = tf.convert_to_tensor(y_pred, dtype=tf.float32) + inception_score = tf.keras.backend.eval(classifier_score_from_logits(logits)) y_pred = np.argmax(y_pred, axis=1) y_true = np.argmax(y_true, axis=1) plot_example_errors(y_pred, y_true, x_test) - return accuracy_score(y_true, y_pred) + return accuracy_score(y_true, y_pred), inception_score def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0): @@ -162,4 +166,4 @@ if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() print(y_t.shape) model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3) - test_classifier(model, x_t, y_t) + print(test_classifier(model, x_t, y_t)) |