| Field | Value | Date |
|---|---|---|
| author | nunzip <np.scarh@gmail.com> | 2019-02-07 02:26:18 +0000 |
| committer | nunzip <np.scarh@gmail.com> | 2019-02-07 02:26:18 +0000 |
| commit | e75a0e4f63e657b216be6b500430ef35d19fe3bb (patch) | |
| tree | ae725cfec28543df03a2ff10914210481f9924ec | |
| parent | 13d71482da52bb3ca7c56a0054e8d151a441aee0 (diff) | |
Fix what is returned in testmode
-rwxr-xr-x | evaluate.py | 34 |
1 file changed, 19 insertions, 15 deletions
```diff
diff --git a/evaluate.py b/evaluate.py
index d4d0b07..8be1412 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -84,15 +84,19 @@ def run_model (data, train, test, train_part, args):
     train_pred = clf.predict(hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])))
     train_label = np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])
 
+    print(test_pred)
+
     if args.timer:
         end = time.time()
         print("Execution time: ",end - start)
     if args.conf_mat:
         skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True)
         plt.show()
-
-    return accuracy_score(test_pred, test_label)#, accuracy_score(train_pred, train_label), end-start
-
+
+    if args.testmode:
+        return accuracy_score(test_pred, test_label), accuracy_score(train_pred, train_label), end-start
+    else:
+        return accuracy_score(test_pred, test_label)
 
 def main():
@@ -103,29 +107,29 @@ def main():
     logging.debug("Verbose is on")
 
     if args.testmode:
-        acc = np.zeros((3,100))
-        a = np.zeros(100)
-        for i in range(100):
-            if i <= 10:
-                args.kmean = (i+1)
+        acc = np.zeros((3,50))
+        a = np.zeros(50)
+        for i in range(50):
+            if i <= 9:
+                args.estimators = i+1
+            elif i <= 20:
+                args.estimators = i*3
             else:
-                args.kmean = 15*i
-            a[i] = args.kmean
-            print("Kmeans: ",args.kmean)
+                args.estimators = i*10
+            a[i] = args.estimators
+            print("Step: ",i)
             acc[0][i], acc[1][i], acc[2][i] = run_model (data, train, test, train_part, args)
             print("Accuracy test:",acc[0][i], "Accuracy train:", acc[1][i])
         plt.plot(a,1-acc[0])
         plt.plot(a,1-acc[1])
         plt.ylabel('Normalized Classification Error')
-        plt.xlabel('Vocabulary size')
-        plt.title('Classification error varying vocabulary size')
+        plt.xlabel('Number of Trees')
         plt.legend(('Test','Train'),loc='upper right')
         plt.show()
         plt.plot(a,acc[2])
         plt.ylabel('Time (s)')
-        plt.xlabel('Vocabulary size')
-        plt.title('Time complexity varying vocabulary size')
+        plt.xlabel('Tree Depth')
         plt.show()
     else:
```
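For context, the sketch below shows the pattern this commit sets up, in isolation: `run_model` returns `(test accuracy, train accuracy, elapsed time)` only when test mode is on, and `main` sweeps the number of estimators over the schedule `i+1`, `3*i`, `10*i` before plotting classification error and runtime against the number of trees. It is a minimal stand-in, not code from `evaluate.py`: the synthetic data, the `RandomForestClassifier`, and the simplified `run_model` signature are all assumptions made for illustration (the diff only shows that `args.estimators` sets the size of an ensemble labelled "Number of Trees").

```python
# Minimal sketch of the test-mode sweep this commit sets up.
# Assumptions (not from the repo): synthetic data stands in for the
# real features, and run_model is reduced to a random-forest fit.
import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score


def run_model(X_train, y_train, X_test, y_test, n_estimators, testmode=False):
    """Train a forest; in test mode also report train accuracy and runtime."""
    start = time.time()
    clf = RandomForestClassifier(n_estimators=n_estimators, random_state=0)
    clf.fit(X_train, y_train)
    test_acc = accuracy_score(y_test, clf.predict(X_test))
    if testmode:
        train_acc = accuracy_score(y_train, clf.predict(X_train))
        return test_acc, train_acc, time.time() - start
    return test_acc


def main():
    # Synthetic two-class data (assumption: stands in for the repo's dataset).
    rng = np.random.default_rng(0)
    X = rng.normal(size=(600, 20))
    y = (X[:, 0] + X[:, 1] > 0).astype(int)
    X_train, X_test, y_train, y_test = X[:400], X[400:], y[:400], y[400:]

    steps = 50
    acc = np.zeros((3, steps))   # rows: test accuracy, train accuracy, time
    a = np.zeros(steps)          # number of trees at each step
    for i in range(steps):
        # Same step schedule as the diff: fine-grained at first, coarser later.
        if i <= 9:
            n_estimators = i + 1
        elif i <= 20:
            n_estimators = i * 3
        else:
            n_estimators = i * 10
        a[i] = n_estimators
        acc[0][i], acc[1][i], acc[2][i] = run_model(
            X_train, y_train, X_test, y_test, n_estimators, testmode=True)

    # Classification error vs. forest size, for both splits.
    plt.plot(a, 1 - acc[0], label="Test")
    plt.plot(a, 1 - acc[1], label="Train")
    plt.xlabel("Number of Trees")
    plt.ylabel("Normalized Classification Error")
    plt.legend(loc="upper right")
    plt.show()

    # Training/evaluation time vs. forest size.
    plt.plot(a, acc[2])
    plt.xlabel("Number of Trees")
    plt.ylabel("Time (s)")
    plt.show()


if __name__ == "__main__":
    main()
```

The uneven step schedule keeps the sweep cheap: forest sizes are sampled densely where accuracy tends to change fastest (few trees) and coarsely once it plateaus, which is why a 3×50 `acc` array is enough to hold the (test accuracy, train accuracy, time) triple returned at each step.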