From 5b9daa1503dce4228eb0c78f9b0a17f018cb8114 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Mon, 4 Feb 2019 18:08:22 +0000 Subject: Complete question 2 by adding RF --- evaluate.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/evaluate.py b/evaluate.py index b4d4a33..9b26920 100755 --- a/evaluate.py +++ b/evaluate.py @@ -4,32 +4,38 @@ DATA_FILE = 'data.npz' CLUSTER_CNT = 256 -KMEAN_PART = 33 import numpy as np import matplotlib.pyplot as plt from sklearn.cluster import KMeans +from sklearn.ensemble import RandomForestClassifier data = np.load(DATA_FILE) train = data['desc_tr'] -train_part = data['desc_sel'].T +train_part = data['desc_sel'].T[0:1000] print("Computing KMeans with", train_part.shape[0], "keywords") kmeans = KMeans(n_clusters=CLUSTER_CNT, random_state=0).fit(train_part) print("Generating histograms") histogram = np.zeros((train.shape[0], train.shape[1],CLUSTER_CNT)) + for i in range(train.shape[0]): for j in range(train.shape[1]): histogram[i][j] = np.bincount(kmeans.predict(train[i][j].T),minlength=CLUSTER_CNT) -print(histogram.shape) +print("Keywords shape", histogram.shape, "\n") +print("Planting trees...") +clf = RandomForestClassifier() +clf.fit( + histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])), + np.repeat(np.arange(histogram.shape[0]), histogram.shape[1])) + +print("Random forests created") + +print(clf.score( + histogram.reshape((histogram.shape[0]*histogram.shape[1], histogram.shape[2])), + np.repeat(np.arange(histogram.shape[0]), histogram.shape[1]))) -plt.hist(histogram[1][5]) -plt.show() -plt.hist(histogram[3][2]) -plt.show() -plt.hist(histogram[7][8]) -plt.show() -- cgit v1.2.3-54-g00ecf