diff options
-rwxr-xr-x | evaluate.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/evaluate.py b/evaluate.py index 92c4107..2de8f8c 100755 --- a/evaluate.py +++ b/evaluate.py @@ -7,9 +7,9 @@ CLUSTER_CNT = 256 KMEANS = False if KMEANS: - N_ESTIMATORS = 1000 -else: N_ESTIMATORS = 1 +else: + N_ESTIMATORS = 100 import numpy as np import matplotlib.pyplot as plt @@ -38,13 +38,11 @@ def make_histogram(data): histogram = np.zeros((data.shape[0], data.shape[1],CLUSTER_CNT*N_ESTIMATORS)) for i in range(data.shape[0]): for j in range(data.shape[1]): - print(data[i][j].shape) if (KMEANS): histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=CLUSTER_CNT) else: leaves = trees.apply(data[i][j].T) leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=CLUSTER_CNT) - print(leaves.shape) histogram[i][j] = leaves.reshape(CLUSTER_CNT*N_ESTIMATORS) return histogram |