aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xevaluate.py52
1 files changed, 39 insertions, 13 deletions
diff --git a/evaluate.py b/evaluate.py
index 9b26920..9e3a613 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -4,38 +4,64 @@
DATA_FILE = 'data.npz'
CLUSTER_CNT = 256
+KMEANS = False
+
+if KMEANS:
+ N_ESTIMATORS = 10
+else:
+ N_ESTIMATORS = 1
import numpy as np
import matplotlib.pyplot as plt
+import scikitplot as skplt
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
+from sklearn.ensemble import RandomTreesEmbedding
data = np.load(DATA_FILE)
train = data['desc_tr']
-train_part = data['desc_sel'].T[0:1000]
+test = data['desc_te']
+train_part = data['desc_sel'].T
+
+if (KMEANS):
+ print("Computing KMeans with", train_part.shape[0], "keywords")
+ kmeans = KMeans(n_clusters=CLUSTER_CNT, n_init=N_ESTIMATORS, random_state=0).fit(train_part)
+else:
+ trees = RandomTreesEmbedding(max_leaf_nodes=256, n_estimators=N_ESTIMATORS, random_state=0).fit(train_part)
-print("Computing KMeans with", train_part.shape[0], "keywords")
-kmeans = KMeans(n_clusters=CLUSTER_CNT, random_state=0).fit(train_part)
print("Generating histograms")
-histogram = np.zeros((train.shape[0], train.shape[1],CLUSTER_CNT))
-for i in range(train.shape[0]):
- for j in range(train.shape[1]):
- histogram[i][j] = np.bincount(kmeans.predict(train[i][j].T),minlength=CLUSTER_CNT)
+def make_histogram(data):
+ histogram = np.zeros((data.shape[0], data.shape[1],CLUSTER_CNT*N_ESTIMATORS))
+ for i in range(data.shape[0]):
+ for j in range(data.shape[1]):
+ print(data[i][j].shape)
+ if (KMEANS):
+ histogram[i][j] = np.bincount(kmeans.predict(data[i][j].T),minlength=CLUSTER_CNT)
+ else:
+ leaves = trees.apply(data[i][j].T)
+ leaves = np.apply_along_axis(np.bincount, axis=0, arr=leaves, minlength=CLUSTER_CNT)
+ print(leaves.shape)
+ histogram[i][j] = leaves.reshape(CLUSTER_CNT*N_ESTIMATORS)
+ return histogram
+
+hist_train = make_histogram(train)
+hist_test = make_histogram(test)
-print("Keywords shape", histogram.shape, "\n")
+print("Keywords shape", hist_train.shape, "\n")
print("Planting trees...")
clf = RandomForestClassifier()
clf.fit(
- histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])),
- np.repeat(np.arange(histogram.shape[0]), histogram.shape[1]))
+ hist_train.reshape((hist_train.shape[0]*hist_train.shape[1], hist_train.shape[2])),
+ np.repeat(np.arange(hist_train.shape[0]), hist_train.shape[1]))
print("Random forests created")
-print(clf.score(
- histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])),
- np.repeat(np.arange(histogram.shape[0]), histogram.shape[1])))
+test_pred = clf.predict(hist_test.reshape((hist_test.shape[0]*hist_test.shape[1], hist_test.shape[2])))
+test_label = np.repeat(np.arange(hist_test.shape[0]), hist_test.shape[1])
+skplt.metrics.plot_confusion_matrix(test_pred, test_label, normalize=True)
+plt.show()