aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-07 02:26:18 +0000
committernunzip <np.scarh@gmail.com>2019-02-07 02:26:18 +0000
commite75a0e4f63e657b216be6b500430ef35d19fe3bb (patch)
treeae725cfec28543df03a2ff10914210481f9924ec
parent13d71482da52bb3ca7c56a0054e8d151a441aee0 (diff)
downloade4-vision-e75a0e4f63e657b216be6b500430ef35d19fe3bb.tar.gz
e4-vision-e75a0e4f63e657b216be6b500430ef35d19fe3bb.tar.bz2
e4-vision-e75a0e4f63e657b216be6b500430ef35d19fe3bb.zip
Fix what is returned in testmode
-rwxr-xr-xevaluate.py34
1 files changed, 19 insertions, 15 deletions
diff --git a/evaluate.py b/evaluate.py
index d4d0b07..8be1412 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -84,15 +84,19 @@ def run_model (data, train, test, train_part, args):
train_pred = clf.predict(hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])))
train_label = np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])
+ print(test_pred)
+
if args.timer:
end = time.time()
print("Execution time: ",end - start)
if args.conf_mat:
skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True)
plt.show()
-
- return accuracy_score(test_pred, test_label)#, accuracy_score(train_pred, train_label), end-start
-
+
+ if args.testmode:
+ return accuracy_score(test_pred, test_label), accuracy_score(train_pred, train_label), end-start
+ else:
+ return accuracy_score(test_pred, test_label)
def main():
@@ -103,29 +107,29 @@ def main():
logging.debug("Verbose is on")
if args.testmode:
- acc = np.zeros((3,100))
- a = np.zeros(100)
- for i in range(100):
- if i <= 10:
- args.kmean = (i+1)
+ acc = np.zeros((3,50))
+ a = np.zeros(50)
+ for i in range(50):
+ if i <= 9:
+ args.estimators = i+1
+ elif i <= 20:
+ args.estimators = i*3
else:
- args.kmean = 15*i
- a[i] = args.kmean
- print("Kmeans: ",args.kmean)
+ args.estimators = i*10
+ a[i] = args.estimators
+ print("Step: ",i)
acc[0][i], acc[1][i], acc[2][i] = run_model (data, train, test, train_part, args)
print("Accuracy test:",acc[0][i], "Accuracy train:", acc[1][i])
plt.plot(a,1-acc[0])
plt.plot(a,1-acc[1])
plt.ylabel('Normalized Classification Error')
- plt.xlabel('Vocabulary size')
- plt.title('Classification error varying vocabulary size')
+ plt.xlabel('Number of Trees')
plt.legend(('Test','Train'),loc='upper right')
plt.show()
plt.plot(a,acc[2])
plt.ylabel('Time (s)')
- plt.xlabel('Vocabulary size')
- plt.title('Time complexity varying vocabulary size')
+ plt.xlabel('Tree Depth')
plt.show()
else: