diff options
| author | Vasil Zlatanov <v@skozl.com> | 2019-02-27 19:48:46 +0000 | 
|---|---|---|
| committer | Vasil Zlatanov <v@skozl.com> | 2019-02-27 19:48:46 +0000 | 
| commit | 7eae715bbac0a3ace5075e79e1b85b2ceb45c278 (patch) | |
| tree | 59df8d0ea01d59723f21b38e2a337ba75c83b788 | |
| parent | fae3543a1a5ee731dd4f961a860c856521f7f3fd (diff) | |
| download | e4-gan-7eae715bbac0a3ace5075e79e1b85b2ceb45c278.tar.gz e4-gan-7eae715bbac0a3ace5075e79e1b85b2ceb45c278.tar.bz2 e4-gan-7eae715bbac0a3ace5075e79e1b85b2ceb45c278.zip | |
Also one_hot test
| -rw-r--r-- | lenet.py | 5 | 
1 files changed, 3 insertions, 2 deletions
| @@ -20,6 +20,7 @@ def import_mnist():    X_test, y_test = mnist.test.images, mnist.test.labels    y_train = keras.utils.to_categorical(y_train, 10)    y_validation = keras.utils.to_categorical(y_validation, 10)   +  y_test = keras.utils.to_categorical(y_test, 10)    return X_train, y_train, X_validation, y_validation, X_test, y_test  def plot_images(images, cls_true, cls_pred=None): @@ -127,6 +128,6 @@ def test_classifier(model, x_test, y_true):  # If file run directly, perform quick test  if __name__ == '__main__': -  x_train, y_train, x_val, y_val, _, _ = import_mnist() +  x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()    model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=1) -  test_classifier(model, x_val, y_val) +  test_classifier(model, x_t, y_t) | 
