aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2018-11-19 17:48:45 +0000
committerVasil Zlatanov <v@skozl.com>2018-11-19 17:48:45 +0000
commit56997876d5f131713f5f3525d413ee391d85e482 (patch)
treededbe921d82f1fab4375839a135662476c7b5219
parent477647a16a1fabc010931c64df711bff0fe1c79d (diff)
downloadvz215_np1915-56997876d5f131713f5f3525d413ee391d85e482.tar.gz
vz215_np1915-56997876d5f131713f5f3525d413ee391d85e482.tar.bz2
vz215_np1915-56997876d5f131713f5f3525d413ee391d85e482.zip
Add cm for ensemble
-rwxr-xr-xtrain.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/train.py b/train.py
index dd3633f..3c867a2 100755
--- a/train.py
+++ b/train.py
@@ -277,6 +277,14 @@ def main():
target_pred_comb = np.zeros(target_pred.shape[1])
target_pred = target_pred.astype(int).T
+ if (args.conf_mat):
+ cm = confusion_matrix(np.tile(target_test, args.ensemble), target_pred.flatten('F'))
+ plt.matshow(cm, cmap='Blues')
+ plt.colorbar()
+ plt.ylabel('Actual')
+ plt.xlabel('Predicted')
+ plt.show()
+
for i in range(target_pred.shape[0]):
target_pred_comb[i] = np.bincount(target_pred[i]).argmax()
target_pred = target_pred_comb