diff options
| author | Vasil Zlatanov <v@skozl.com> | 2019-02-27 18:23:44 +0000 | 
|---|---|---|
| committer | Vasil Zlatanov <v@skozl.com> | 2019-02-27 18:23:44 +0000 | 
| commit | 7465d4fdde046843cb8bca3b233c2cdd99c39722 (patch) | |
| tree | 43beef701365cbb249cf4aac04b171665e09fd99 | |
| parent | ab1f096a45c2a42a2a3d6e8bf9e60ef0f1ba1b51 (diff) | |
| download | e4-gan-7465d4fdde046843cb8bca3b233c2cdd99c39722.tar.gz e4-gan-7465d4fdde046843cb8bca3b233c2cdd99c39722.tar.bz2 e4-gan-7465d4fdde046843cb8bca3b233c2cdd99c39722.zip | |
Fix lenet funcs
| -rw-r--r-- | lenet.py | 28 | 
1 files changed, 18 insertions, 10 deletions
| @@ -8,7 +8,7 @@ from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D  from tensorflow.keras import backend as K  from tensorflow.keras import optimizers  import matplotlib.pyplot as plt -import tensorflow.keras.metrics +from tensorflow.keras.metrics import categorical_accuracy  import numpy as np  import random @@ -21,8 +21,9 @@ def import_mnist():    X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')    X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')    X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') -   -  return X_train, X_validation, X_test, y_train, y_validation, y_test +  y_train = keras.utils.to_categorical(y_train, 10) +  y_validation = keras.utils.to_categorical(y_validation, 10)   +  return X_train, y_train, X_validation, y_validation, X_test, y_test  def plot_images(images, cls_true, cls_pred=None):      assert len(images) == len(cls_true) == 9 @@ -44,8 +45,12 @@ def plot_images(images, cls_true, cls_pred=None):      plt.show()  def plot_example_errors(y_pred, y_true, X_test): +    y_pred = np.argmax(y_pred, axis=1) +    y_true = np.argmax(y_true, axis=1)      correct_prediction = np.equal(y_pred, y_true)      incorrect = np.equal(correct_prediction, False) +    print(correct_prediction.shape) +    print(incorrect[0])      images = X_test[incorrect]      cls_pred = y_pred[incorrect]      cls_true = y_true[incorrect] @@ -53,7 +58,7 @@ def plot_example_errors(y_pred, y_true, X_test):  def get_lenet(shape):    model = keras.Sequential() -  model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape))) +  model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape))    model.add(AveragePooling2D())    model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) @@ -101,20 +106,23 @@ def plot_history(history, metric = None):      plt.ylabel('Loss')      plt.xlabel('Epoch') -def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, EPOCHS=100, num_classes=10): -  y_train = keras.utils.to_categorical(y_train, num_classes) -  y_val = keras.utils.to_categorical(y_val, num_classes)   +def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100):    shape = (32, 32, 1)    model = get_lenet(shape)    sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)    model.compile(loss='categorical_crossentropy', optimizer=sgd) -  history = model.fit(x_train, y_train, batch_size=batch_size, epochs=EPOCHS, verbose=1, validation_data = (x_val, y_val)) +  history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data = (x_val, y_val))    plot_history(history)    return model   def test_classifier(model, x_test, y_true):    y_pred = model.predict(x_test) -  print(metrics.categorical_accuracy(y_true, y_pred)) -  plot_example_errors(y_pred, y_true, x_test)  +  print(categorical_accuracy(y_true, y_pred)) +  plot_example_errors(y_pred, y_true, x_test) + +# If file run directly, perform quick test +x_train, y_train, x_val, y_val, _, _ = import_mnist() +model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1) +test_classifier(model, x_val, y_val) | 
