From f9c0139f63438a2574d9931f732cb3aecd172485 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Sun, 10 Mar 2019 19:00:54 +0000 Subject: Fix y_true in plot_probas --- lenet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lenet.py b/lenet.py index 0fe8277..a94259b 100644 --- a/lenet.py +++ b/lenet.py @@ -144,6 +144,7 @@ def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, return model def plot_probas(model, x_test, y_true): + y_true = np.argmax(y_true, axis=1) x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') probas = model.predict(x_test) skplt.metrics.plot_roc(y_true, probas) -- cgit v1.2.3-54-g00ecf