diff options
| -rwxr-xr-x | evaluate.py | 24 | 
1 files changed, 15 insertions, 9 deletions
diff --git a/evaluate.py b/evaluate.py index b4d4a33..9b26920 100755 --- a/evaluate.py +++ b/evaluate.py @@ -4,32 +4,38 @@  DATA_FILE = 'data.npz'  CLUSTER_CNT = 256 -KMEAN_PART = 33  import numpy as np  import matplotlib.pyplot as plt  from sklearn.cluster import KMeans +from sklearn.ensemble import RandomForestClassifier  data = np.load(DATA_FILE)  train = data['desc_tr'] -train_part = data['desc_sel'].T +train_part = data['desc_sel'].T[0:1000]  print("Computing KMeans with", train_part.shape[0], "keywords")  kmeans = KMeans(n_clusters=CLUSTER_CNT, random_state=0).fit(train_part)  print("Generating histograms")  histogram = np.zeros((train.shape[0], train.shape[1],CLUSTER_CNT)) +  for i in range(train.shape[0]):      for j in range(train.shape[1]):          histogram[i][j] = np.bincount(kmeans.predict(train[i][j].T),minlength=CLUSTER_CNT) -print(histogram.shape) +print("Keywords shape", histogram.shape, "\n") +print("Planting trees...") +clf = RandomForestClassifier() +clf.fit( +        histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])), +        np.repeat(np.arange(histogram.shape[0]), histogram.shape[1])) + +print("Random forests created") + +print(clf.score( +        histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])), +        np.repeat(np.arange(histogram.shape[0]), histogram.shape[1]))) -plt.hist(histogram[1][5]) -plt.show() -plt.hist(histogram[3][2]) -plt.show() -plt.hist(histogram[7][8]) -plt.show()  | 
