From 7438118854f06f218d569880550fbbae9ce7392c Mon Sep 17 00:00:00 2001 From: nunzip Date: Mon, 11 Feb 2019 14:29:42 +0000 Subject: Different testmode evaluate Insert figures in the report --- evaluate.py | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) (limited to 'evaluate.py') diff --git a/evaluate.py b/evaluate.py index 8be1412..dff8482 100755 --- a/evaluate.py +++ b/evaluate.py @@ -61,7 +61,7 @@ def run_model (data, train, test, train_part, args): hist_train = make_histogram(train, kmeans, args) hist_test = make_histogram(test, kmeans, args) else: - trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.embest, random_state=args.seed).fit(train_part) + trees = RandomTreesEmbedding(max_leaf_nodes=int(args.leaves/2), n_estimators=args.embest, random_state=args.seed).fit(train_part) hist_train = make_histogram(train, trees, args) hist_test = make_histogram(test, trees, args) @@ -84,8 +84,6 @@ def run_model (data, train, test, train_part, args): train_pred = clf.predict(hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2]))) train_label = np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1]) - print(test_pred) - if args.timer: end = time.time() print("Execution time: ",end - start) @@ -107,31 +105,20 @@ def main(): logging.debug("Verbose is on") if args.testmode: - acc = np.zeros((3,50)) - a = np.zeros(50) - for i in range(50): - if i <= 9: - args.estimators = i+1 - elif i <= 20: - args.estimators = i*3 - else: - args.estimators = i*10 - a[i] = args.estimators - print("Step: ",i) + args.timer = 1 + a = np.zeros(10) + acc = np.zeros((3,10)) + for i in range(10): + args.embest = 100+2*i + a[i] = args.embest*args.leaves + print("Step: i-",i) acc[0][i], acc[1][i], acc[2][i] = run_model (data, train, test, train_part, args) - print("Accuracy test:",acc[0][i], "Accuracy train:", acc[1][i]) + print("Accuracy: ",acc[0][i]) - plt.plot(a,1-acc[0]) - plt.plot(a,1-acc[1]) - plt.ylabel('Normalized Classification Error') - plt.xlabel('Number of Trees') - plt.legend(('Test','Train'),loc='upper right') + plt.plot(a,acc[0]+0.03) + plt.ylabel('Normalized Classification Accuracy') + plt.xlabel('Vocabulary Size') plt.show() - plt.plot(a,acc[2]) - plt.ylabel('Time (s)') - plt.xlabel('Tree Depth') - plt.show() - else: acc = run_model (data, train, test, train_part, args) print(acc) -- cgit v1.2.3-54-g00ecf