aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-06 14:09:18 +0000
committernunzip <np.scarh@gmail.com>2019-02-06 14:09:18 +0000
commit8d6c8e5dfe6cfdc8ffb886940cb8370006f3d1a2 (patch)
tree1a63c7c2fb534ca474facb0e6ceb390129e18698
parent6213982b229e42fa52a758d6d63c5b5351a39aec (diff)
downloade4-vision-8d6c8e5dfe6cfdc8ffb886940cb8370006f3d1a2.tar.gz
e4-vision-8d6c8e5dfe6cfdc8ffb886940cb8370006f3d1a2.tar.bz2
e4-vision-8d6c8e5dfe6cfdc8ffb886940cb8370006f3d1a2.zip
Set random state and add tree_depth to RFC
-rwxr-xr-xevaluate.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/evaluate.py b/evaluate.py
index b524ebd..ff52c30 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -57,7 +57,7 @@ def run_model (data, train, test, train_part, args):
hist_train = make_histogram(train, kmeans, args)
hist_test = make_histogram(test, kmeans, args)
else:
- trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, max_depth=args.treedepth, n_estimators=args.estimators, random_state=0).fit(train_part)
+ trees = RandomTreesEmbedding(max_leaf_nodes=args.leaves, n_estimators=args.estimators, random_state=0).fit(train_part)
hist_train = make_histogram(train, trees, args)
hist_test = make_histogram(test, trees, args)
@@ -66,7 +66,7 @@ def run_model (data, train, test, train_part, args):
logging.debug("Keywords shape", hist_train.shape, "\n")
logging.debug("Planting trees...")
- clf = RandomForestClassifier()
+ clf = RandomForestClassifier(n_estimators=args.estimators, max_depth=args.treedepth, random_state=0)
clf.fit(
hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])),
np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1]))
@@ -95,7 +95,6 @@ def main():
logging.debug("Verbose is on")
if args.testmode:
- args.leaves = 10000
cnt = 0
acc = np.zeros((5,5))
for i in range(5):