aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-04 18:08:22 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-04 18:08:22 +0000
commit5b9daa1503dce4228eb0c78f9b0a17f018cb8114 (patch)
tree23a5c1b47fa1378381a9c8aa7b51009867843e39
parent94c77a3b103f7ec91b6807e1c45d910628f4bcfc (diff)
downloade4-vision-5b9daa1503dce4228eb0c78f9b0a17f018cb8114.tar.gz
e4-vision-5b9daa1503dce4228eb0c78f9b0a17f018cb8114.tar.bz2
e4-vision-5b9daa1503dce4228eb0c78f9b0a17f018cb8114.zip
Complete question 2 by adding RF
-rwxr-xr-xevaluate.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/evaluate.py b/evaluate.py
index b4d4a33..9b26920 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -4,32 +4,38 @@
DATA_FILE = 'data.npz'
CLUSTER_CNT = 256
-KMEAN_PART = 33
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
+from sklearn.ensemble import RandomForestClassifier
data = np.load(DATA_FILE)
train = data['desc_tr']
-train_part = data['desc_sel'].T
+train_part = data['desc_sel'].T[0:1000]
print("Computing KMeans with", train_part.shape[0], "keywords")
kmeans = KMeans(n_clusters=CLUSTER_CNT, random_state=0).fit(train_part)
print("Generating histograms")
histogram = np.zeros((train.shape[0], train.shape[1],CLUSTER_CNT))
+
for i in range(train.shape[0]):
for j in range(train.shape[1]):
histogram[i][j] = np.bincount(kmeans.predict(train[i][j].T),minlength=CLUSTER_CNT)
-print(histogram.shape)
+print("Keywords shape", histogram.shape, "\n")
+print("Planting trees...")
+clf = RandomForestClassifier()
+clf.fit(
+ histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])),
+ np.repeat(np.arange(histogram.shape[0]), histogram.shape[1]))
+
+print("Random forests created")
+
+print(clf.score(
+ histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])),
+ np.repeat(np.arange(histogram.shape[0]), histogram.shape[1])))
-plt.hist(histogram[1][5])
-plt.show()
-plt.hist(histogram[3][2])
-plt.show()
-plt.hist(histogram[7][8])
-plt.show()