aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtrain.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/train.py b/train.py
index dd3633f..3c867a2 100755
--- a/train.py
+++ b/train.py
@@ -277,6 +277,14 @@ def main():
target_pred_comb = np.zeros(target_pred.shape[1])
target_pred = target_pred.astype(int).T
+ if (args.conf_mat):
+ cm = confusion_matrix(np.tile(target_test, args.ensemble), target_pred.flatten('F'))
+ plt.matshow(cm, cmap='Blues')
+ plt.colorbar()
+ plt.ylabel('Actual')
+ plt.xlabel('Predicted')
+ plt.show()
+
for i in range(target_pred.shape[0]):
target_pred_comb[i] = np.bincount(target_pred[i]).argmax()
target_pred = target_pred_comb