diff options
-rwxr-xr-x | train.py | 8 |
1 files changed, 8 insertions, 0 deletions
@@ -277,6 +277,14 @@ def main(): target_pred_comb = np.zeros(target_pred.shape[1]) target_pred = target_pred.astype(int).T + if (args.conf_mat): + cm = confusion_matrix(np.tile(target_test, args.ensemble), target_pred.flatten('F')) + plt.matshow(cm, cmap='Blues') + plt.colorbar() + plt.ylabel('Actual') + plt.xlabel('Predicted') + plt.show() + for i in range(target_pred.shape[0]): target_pred_comb[i] = np.bincount(target_pred[i]).argmax() target_pred = target_pred_comb |