diff options
| author | Vasil Zlatanov <v@skozl.com> | 2019-02-04 18:08:22 +0000 | 
|---|---|---|
| committer | Vasil Zlatanov <v@skozl.com> | 2019-02-04 18:08:22 +0000 | 
| commit | 5b9daa1503dce4228eb0c78f9b0a17f018cb8114 (patch) | |
| tree | 23a5c1b47fa1378381a9c8aa7b51009867843e39 | |
| parent | 94c77a3b103f7ec91b6807e1c45d910628f4bcfc (diff) | |
| download | e4-vision-5b9daa1503dce4228eb0c78f9b0a17f018cb8114.tar.gz e4-vision-5b9daa1503dce4228eb0c78f9b0a17f018cb8114.tar.bz2 e4-vision-5b9daa1503dce4228eb0c78f9b0a17f018cb8114.zip  | |
Complete question 2 by adding RF
| -rwxr-xr-x | evaluate.py | 24 | 
1 files changed, 15 insertions, 9 deletions
diff --git a/evaluate.py b/evaluate.py index b4d4a33..9b26920 100755 --- a/evaluate.py +++ b/evaluate.py @@ -4,32 +4,38 @@  DATA_FILE = 'data.npz'  CLUSTER_CNT = 256 -KMEAN_PART = 33  import numpy as np  import matplotlib.pyplot as plt  from sklearn.cluster import KMeans +from sklearn.ensemble import RandomForestClassifier  data = np.load(DATA_FILE)  train = data['desc_tr'] -train_part = data['desc_sel'].T +train_part = data['desc_sel'].T[0:1000]  print("Computing KMeans with", train_part.shape[0], "keywords")  kmeans = KMeans(n_clusters=CLUSTER_CNT, random_state=0).fit(train_part)  print("Generating histograms")  histogram = np.zeros((train.shape[0], train.shape[1],CLUSTER_CNT)) +  for i in range(train.shape[0]):      for j in range(train.shape[1]):          histogram[i][j] = np.bincount(kmeans.predict(train[i][j].T),minlength=CLUSTER_CNT) -print(histogram.shape) +print("Keywords shape", histogram.shape, "\n") +print("Planting trees...") +clf = RandomForestClassifier() +clf.fit( +        histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])), +        np.repeat(np.arange(histogram.shape[0]), histogram.shape[1])) + +print("Random forests created") + +print(clf.score( +        histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])), +        np.repeat(np.arange(histogram.shape[0]), histogram.shape[1]))) -plt.hist(histogram[1][5]) -plt.show() -plt.hist(histogram[3][2]) -plt.show() -plt.hist(histogram[7][8]) -plt.show()  | 
