From c9220a2981d3f2b2f6dfc8d51834801b604c9e7d Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Wed, 6 Feb 2019 17:31:29 +0000 Subject: Add seed flag --- evaluate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/evaluate.py b/evaluate.py index ff52c30..df0c79b 100755 --- a/evaluate.py +++ b/evaluate.py @@ -25,6 +25,7 @@ parser.add_argument("-D", "--treedepth", help="depth of trees", type=int, defaul parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true') parser.add_argument("-t", "--timer", help="Display execution time", action='store_true') parser.add_argument("-T", "--testmode", help="Testmode", action='store_true') +parser.add_argument("-s", "--seed", help="Seed to use for random_state when creating trees", type=int, default=0) args = parser.parse_args() if args.verbose: @@ -53,11 +54,11 @@ def run_model (data, train, test, train_part, args): if (args.kmean): logging.debug("Computing KMeans with", train_part.shape[0], "keywords") - kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=0).fit(train_part) + kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=args.seed).fit(train_part) hist_train = make_histogram(train, kmeans, args) hist_test = make_histogram(test, kmeans, args) else: - trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part) + trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=args.seed).fit(train_part) hist_train = make_histogram(train, trees, args) hist_test = make_histogram(test, trees, args) @@ -66,7 +67,7 @@ def run_model (data, train, test, train_part, args): logging.debug("Keywords shape", hist_train.shape, "\n") logging.debug("Planting trees...") - clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=0) + clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=args.seed) clf.fit( hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])), np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])) -- cgit v1.2.3-54-g00ecf