aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-27 18:53:17 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-27 18:53:17 +0000
commita77a1be3117b8018fdb10be3e85e258d5f5c53b4 (patch)
treeec8820245762e11a76fce5542e9a0d8d56cffc5b
parent53f94e754faabe129075b1c288c3c109376c34e8 (diff)
downloade4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.tar.gz
e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.tar.bz2
e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.zip
Move padding to train/test functions
-rw-r--r--lenet.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/lenet.py b/lenet.py
index 37123bb..93a80cf 100644
--- a/lenet.py
+++ b/lenet.py
@@ -18,9 +18,6 @@ def import_mnist():
X_train, y_train = mnist.train.images, mnist.train.labels
X_validation, y_validation = mnist.validation.images, mnist.validation.labels
X_test, y_test = mnist.test.images, mnist.test.labels
- X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
- X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_train = keras.utils.to_categorical(y_train, 10)
y_validation = keras.utils.to_categorical(y_validation, 10)
return X_train, y_train, X_validation, y_validation, X_test, y_test
@@ -108,6 +105,11 @@ def plot_history(history, metric = None):
def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=['categorical_accuracy'], optimizer = None):
shape = (32, 32, 1)
+
+ # Pad data to 32x32 (MNIST is 28x28)
+ x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+
model = get_lenet(shape)
if optimizer = None:
@@ -120,6 +122,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,
return model
def test_classifier(model, x_test, y_true):
+ x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')
y_pred = model.predict(x_test)
print(categorical_accuracy(y_true, y_pred))
plot_example_errors(y_pred, y_true, x_test)