aboutsummaryrefslogtreecommitdiff
path: root/evaluate.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-06 17:31:29 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-06 17:31:29 +0000
commitc9220a2981d3f2b2f6dfc8d51834801b604c9e7d (patch)
treeb52fb146934e9a8257c52e534dd1b625e4a71b6e /evaluate.py
parent8d6c8e5dfe6cfdc8ffb886940cb8370006f3d1a2 (diff)
downloade4-vision-c9220a2981d3f2b2f6dfc8d51834801b604c9e7d.tar.gz
e4-vision-c9220a2981d3f2b2f6dfc8d51834801b604c9e7d.tar.bz2
e4-vision-c9220a2981d3f2b2f6dfc8d51834801b604c9e7d.zip
Add seed flag
Diffstat (limited to 'evaluate.py')
-rwxr-xr-xevaluate.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/evaluate.py b/evaluate.py
index ff52c30..df0c79b 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -25,6 +25,7 @@ parser.add_argument("-D", "--treedepth", help="depth of trees", type=int, defaul
parser.add_argument("-v", "--verbose", help="Use verbose output", action='store_true')
parser.add_argument("-t", "--timer", help="Display execution time", action='store_true')
parser.add_argument("-T", "--testmode", help="Testmode", action='store_true')
+parser.add_argument("-s", "--seed", help="Seed to use for random_state when creating trees", type=int, default=0)
args = parser.parse_args()
if args.verbose:
@@ -53,11 +54,11 @@ def run_model (data, train, test, train_part, args):
if (args.kmean):
logging.debug("Computing KMeans with", train_part.shape[0], "keywords")
- kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=0).fit(train_part)
+ kmeans = KMeans(n_clusters=args.kmean, n_init=args.estimators, random_state=args.seed).fit(train_part)
hist_train = make_histogram(train, kmeans, args)
hist_test = make_histogram(test, kmeans, args)
else:
- trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part)
+ trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=args.seed).fit(train_part)
hist_train = make_histogram(train, trees, args)
hist_test = make_histogram(test, trees, args)
@@ -66,7 +67,7 @@ def run_model (data, train, test, train_part, args):
logging.debug("Keywords shape", hist_train.shape, "\n")
logging.debug("Planting trees...")
- clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=0)
+ clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=args.seed)
clf.fit(
hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])),
np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1]))