diff options
| author | Vasil Zlatanov <v@skozl.com> | 2019-02-27 18:53:17 +0000 | 
|---|---|---|
| committer | Vasil Zlatanov <v@skozl.com> | 2019-02-27 18:53:17 +0000 | 
| commit | a77a1be3117b8018fdb10be3e85e258d5f5c53b4 (patch) | |
| tree | ec8820245762e11a76fce5542e9a0d8d56cffc5b | |
| parent | 53f94e754faabe129075b1c288c3c109376c34e8 (diff) | |
| download | e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.tar.gz e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.tar.bz2 e4-gan-a77a1be3117b8018fdb10be3e85e258d5f5c53b4.zip | |
Move padding to train/test functions
| -rw-r--r-- | lenet.py | 9 | 
1 files changed, 6 insertions, 3 deletions
| @@ -18,9 +18,6 @@ def import_mnist():    X_train, y_train = mnist.train.images, mnist.train.labels    X_validation, y_validation = mnist.validation.images, mnist.validation.labels    X_test, y_test = mnist.test.images, mnist.test.labels -  X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') -  X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant') -  X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')    y_train = keras.utils.to_categorical(y_train, 10)    y_validation = keras.utils.to_categorical(y_validation, 10)      return X_train, y_train, X_validation, y_validation, X_test, y_test @@ -108,6 +105,11 @@ def plot_history(history, metric = None):  def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=['categorical_accuracy'], optimizer = None):    shape = (32, 32, 1) + +  # Pad data to 32x32 (MNIST is 28x28) +  x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') +  x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant') +    model = get_lenet(shape)    if optimizer = None: @@ -120,6 +122,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100,    return model   def test_classifier(model, x_test, y_true): +  x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')    y_pred = model.predict(x_test)    print(categorical_accuracy(y_true, y_pred))    plot_example_errors(y_pred, y_true, x_test) | 
