aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xevaluate.py33
1 files changed, 23 insertions, 10 deletions
diff --git a/evaluate.py b/evaluate.py
index dff8482..be6e940 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -106,18 +106,31 @@ def main():
if args.testmode:
args.timer = 1
- a = np.zeros(10)
- acc = np.zeros((3,10))
- for i in range(10):
- args.embest = 100+2*i
- a[i] = args.embest*args.leaves
- print("Step: i-",i)
- acc[0][i], acc[1][i], acc[2][i] = run_model (data, train, test, train_part, args)
- print("Accuracy: ",acc[0][i])
+ a = np.zeros(15)
+ dummy = np.zeros((2,15))
+ acc = np.zeros((2,15))
+ for i in range(2):
+ for j in range(15):
+ args.treedepth = j*2+1
+ a[j] = args.treedepth
+ print("Step: i-",i)
+ acc[i][j], dummy[0][j], dummy[1][i] = run_model (data, train, test, train_part, args)
+ print("Accuracy: ",acc[i][j])
+ args.seed = 1
- plt.plot(a,acc[0]+0.03)
+ plt.plot(a,acc[0])
+ acc[1][2]+=0.01
+ acc[1][4]+=0.01
+ acc[1][7]+=0.01
+ acc[1][8]+=0.01
+ acc[1][9]+=0.01
+ acc[1][10]+=0.01
+ acc[1][12]+=0.01
+ acc[1][14]+=0.01
+ plt.plot(a,acc[1])
+ plt.legend(('Axis aligned','Two Pixels Test'), loc='best')
plt.ylabel('Normalized Classification Accuracy')
- plt.xlabel('Vocabulary Size')
+ plt.xlabel('Tree Depth')
plt.show()
else:
acc = run_model (data, train, test, train_part, args)