From 8d6c8e5dfe6cfdc8ffb886940cb8370006f3d1a2 Mon Sep 17 00:00:00 2001 From: nunzip Date: Wed, 6 Feb 2019 14:09:18 +0000 Subject: Set random state and add tree_depth to RFC --- evaluate.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/evaluate.py b/evaluate.py index b524ebd..ff52c30 100755 --- a/evaluate.py +++ b/evaluate.py @@ -57,7 +57,7 @@ def run_model (data, train, test, train_part, args): hist_train = make_histogram(train, kmeans, args) hist_test = make_histogram(test, kmeans, args) else: - trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, max_depth=args.treedepth, n_estimators=args.estimators, random_state=0).fit(train_part) + trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part) hist_train = make_histogram(train, trees, args) hist_test = make_histogram(test, trees, args) @@ -66,7 +66,7 @@ def run_model (data, train, test, train_part, args): logging.debug("Keywords shape", hist_train.shape, "\n") logging.debug("Planting trees...") - clf = RandomForestClassifier() + clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=0) clf.fit( hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])), np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])) @@ -95,7 +95,6 @@ def main(): logging.debug("Verbose is on") if args.testmode: - args.leaves = 10000 cnt = 0 acc = np.zeros((5,5)) for i in range(5): -- cgit v1.2.3-54-g00ecf