diff options
author | Vasil Zlatanov <v@skozl.com> | 2019-02-27 18:53:17 +0000 |
---|---|---|
committer | Vasil Zlatanov <v@skozl.com> | 2019-02-27 18:53:17 +0000 |
commit | a77a1be3117b8018fdb10be3e85e258d5f5c53b4 (patch) | |
tree | ec8820245762e11a76fce5542e9a0d8d56cffc5b | |
parent | 53f94e754faabe129075b1c288c3c109376c34e8 (diff) | |
download | e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.tar.gz e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.tar.bz2 e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.zip |
Move padding to train/test functions
-rw-r--r-- | lenet.py | 9 |
1 files changed, 6 insertions, 3 deletions
@@ -18,9 +18,6 @@ def import_mnist(): X_train, y_train = mnist.train.images, mnist.train.labels X_validation, y_validation = mnist.validation.images, mnist.validation.labels X_test, y_test = mnist.test.images, mnist.test.labels - X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') - X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant') - X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_train = keras.utils.to_categorical(y_train, 10) y_validation = keras.utils.to_categorical(y_validation, 10) return X_train, y_train, X_validation, y_validation, X_test, y_test @@ -108,6 +105,11 @@ def plot_history(history, metric = None): def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=['categorical_accuracy'], optimizer = None): shape = (32, 32, 1) + + # Pad data to 32x32 (MNIST is 28x28) + x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') + x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant') + model = get_lenet(shape) if optimizer = None: @@ -120,6 +122,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, return model def test_classifier(model, x_test, y_true): + x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_pred = model.predict(x_test) print(categorical_accuracy(y_true, y_pred)) plot_example_errors(y_pred, y_true, x_test) |