author nunzip <np.scarh@gmail.com> 2019-03-06 23:58:13 +0000
committer nunzip <np.scarh@gmail.com> 2019-03-06 23:58:13 +0000
commit 06b3e7c9fdae1f86e33f331b5f69cf326afb38e1 (patch)
tree 1395fa8a74abaab88d3bf5f698469174b058b88f /lenet.py
parent 1258e79ceee17b55ee87d5ac3a10ffea76a42dc5 (diff)
parent 5d779afb5a9511323e3402537af172d68930d85c (diff)
Merge branch 'master' of skozl.com:e4-gan
Diffstat (limited to 'lenet.py')
-rw-r--r--  lenet.py | 11
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/lenet.py b/lenet.py
index 3ddab06..97479ed 100644
--- a/lenet.py
+++ b/lenet.py
@@ -13,6 +13,8 @@ import random
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
+from classifier_metrics_impl import classifier_score_from_logits
+
def import_mnist():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", reshape=False)
@@ -62,7 +64,8 @@ def get_lenet(shape):
model.add(Dense(units=120, activation='relu'))
model.add(Dense(units=84, activation='relu'))
- model.add(Dense(units=10, activation = 'softmax'))
+ #model.add(Dense(units=10, activation = 'softmax'))
+ model.add(Dense(units=10, activation = 'relu'))
return model
def plot_history(history, metric = None):
@@ -126,10 +129,12 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
def test_classifier(model, x_test, y_true):
x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
+ logits = tf.convert_to_tensor(y_pred, dtype=tf.float32)
+ inception_score = tf.keras.backend.eval(classifier_score_from_logits(logits))
y_pred = np.argmax(y_pred, axis=1)
y_true = np.argmax(y_true, axis=1)
plot_example_errors(y_pred, y_true, x_test)
- return accuracy_score(y_true, y_pred)
+ return accuracy_score(y_true, y_pred), inception_score
def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0):
@@ -162,4 +167,4 @@ if __name__ == '__main__':
x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()
print(y_t.shape)
model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3)
- test_classifier(model, x_t, y_t)
+ print(test_classifier(model, x_t, y_t))
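
For reference, the classifier_score_from_logits call added above corresponds to the standard Inception Score, IS = exp(E_x[KL(p(y|x) || p(y))]), computed from softmax probabilities derived from the logits. The NumPy sketch below is not part of the commit; the function name, the eps constant, and the random input are illustrative only, and it assumes a logits array shaped (n_samples, n_classes) like the y_pred passed above.

import numpy as np

def inception_score_from_logits(logits, eps=1e-12):
    """Sketch of the Inception Score for logits of shape (n_samples, n_classes)."""
    # Softmax over classes to get p(y|x) for each sample.
    shifted = logits - logits.max(axis=1, keepdims=True)
    p_yx = np.exp(shifted)
    p_yx /= p_yx.sum(axis=1, keepdims=True)
    # Marginal label distribution p(y), averaged over the batch.
    p_y = p_yx.mean(axis=0, keepdims=True)
    # Mean KL(p(y|x) || p(y)) over the batch, then exponentiate.
    kl = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

# Illustrative usage, with random values standing in for model.predict(x_test):
print(inception_score_from_logits(np.random.randn(64, 10)))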