diff options
author | nunzip <np.scarh@gmail.com> | 2019-02-06 21:28:51 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-02-06 21:28:51 +0000 |
commit | 13d71482da52bb3ca7c56a0054e8d151a441aee0 (patch) | |
tree | 6b5d407854f90f7d9e8d7c40f49c5b89a3e0e71b | |
parent | 9c7644482af96b6462d6e8049d93c9f46cf29def (diff) | |
parent | c9220a2981d3f2b2f6dfc8d51834801b604c9e7d (diff) | |
download | e4-vision-13d71482da52bb3ca7c56a0054e8d151a441aee0.tar.gz e4-vision-13d71482da52bb3ca7c56a0054e8d151a441aee0.tar.bz2 e4-vision-13d71482da52bb3ca7c56a0054e8d151a441aee0.zip |
Fixed merge conflicts
-rwxr-xr-x | evaluate.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/evaluate.py b/evaluate.py index 2fb805d..d4d0b07 100755 --- a/evaluate.py +++ b/evaluate.py @@ -28,6 +28,7 @@ parser.add_argument("-t", "--timer", help="Display execution time", action='stor parser.add_argument("-T", "--testmode", help="Testmode", action='store_true') parser.add_argument("-E", "--embest", help="RandomTreesEmbedding estimators", type=int, default=256) parser.add_argument("-r", "--randomness", help="Randomness parameter", type=int, default=0) +parser.add_argument("-s", "--seed", help="Seed to use for random_state when creating trees", type=int, default=0) args = parser.parse_args() if args.verbose: @@ -56,11 +57,11 @@ def run_model (data, train, test, train_part, args): if (args.kmean): logging.debug("Computing KMeans with", train_part.shape[0], "keywords") - kmeans = KMeans(n_clusters=args.kmean, n_init=1, random_state=0).fit(train_part) + kmeans = KMeans(n_clusters=args.kmean, n_init=1, random_state=args.seed).fit(train_part) hist_train = make_histogram(train, kmeans, args) hist_test = make_histogram(test, kmeans, args) else: - trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.embest, random_state=0).fit(train_part) + trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.embest, random_state=args.seed).fit(train_part) hist_train = make_histogram(train, trees, args) hist_test = make_histogram(test, trees, args) @@ -71,7 +72,7 @@ def run_model (data, train, test, train_part, args): if args.randomness: clf = RandomForestClassifier(max_features=args.randomness, n_estimators=args.estimators, max_depth=args.treedepth, random_state=0) else: - clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=0) + clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=args.seed) clf.fit( hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])), np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1])) |