diff options
-rw-r--r-- | lenet.py | 5 |
1 files changed, 3 insertions, 2 deletions
@@ -20,6 +20,7 @@ def import_mnist(): X_test, y_test = mnist.test.images, mnist.test.labels y_train = keras.utils.to_categorical(y_train, 10) y_validation = keras.utils.to_categorical(y_validation, 10) + y_test = keras.utils.to_categorical(y_test, 10) return X_train, y_train, X_validation, y_validation, X_test, y_test def plot_images(images, cls_true, cls_pred=None): @@ -127,6 +128,6 @@ def test_classifier(model, x_test, y_true): # If file run directly, perform quick test if __name__ == '__main__': - x_train, y_train, x_val, y_val, _, _ = import_mnist() + x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1) - test_classifier(model, x_val, y_val) + test_classifier(model, x_t, y_t) |