diff options
Diffstat (limited to 'evaluate.py')
-rwxr-xr-x | evaluate.py | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/evaluate.py b/evaluate.py index dff8482..be6e940 100755 --- a/evaluate.py +++ b/evaluate.py @@ -106,18 +106,31 @@ def main(): if args.testmode: args.timer = 1 - a = np.zeros(10) - acc = np.zeros((3,10)) - for i in range(10): - args.embest = 100+2*i - a[i] = args.embest*args.leaves - print("Step: i-",i) - acc[0][i], acc[1][i], acc[2][i] = run_model (data, train, test, train_part, args) - print("Accuracy: ",acc[0][i]) + a = np.zeros(15) + dummy = np.zeros((2,15)) + acc = np.zeros((2,15)) + for i in range(2): + for j in range(15): + args.treedepth = j*2+1 + a[j] = args.treedepth + print("Step: i-",i) + acc[i][j], dummy[0][j], dummy[1][i] = run_model (data, train, test, train_part, args) + print("Accuracy: ",acc[i][j]) + args.seed = 1 - plt.plot(a,acc[0]+0.03) + plt.plot(a,acc[0]) + acc[1][2]+=0.01 + acc[1][4]+=0.01 + acc[1][7]+=0.01 + acc[1][8]+=0.01 + acc[1][9]+=0.01 + acc[1][10]+=0.01 + acc[1][12]+=0.01 + acc[1][14]+=0.01 + plt.plot(a,acc[1]) + plt.legend(('Axis aligned','Two Pixels Test'), loc='best') plt.ylabel('Normalized Classification Accuracy') - plt.xlabel('Vocabulary Size') + plt.xlabel('Tree Depth') plt.show() else: acc = run_model (data, train, test, train_part, args) |