summaryrefslogtreecommitdiff
path: root/classifier.py
diff options
context:
space:
mode:
Diffstat (limited to 'classifier.py')
-rwxr-xr-xclassifier.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/classifier.py b/classifier.py
new file mode 100755
index 0000000..b8e2b6a
--- /dev/null
+++ b/classifier.py
@@ -0,0 +1,48 @@
+#!/usr/bin/python
+
+from absl import flags
+from absl import logging
+
+import os
+import sys
+import numpy as np
+import tensorflow as tf
+import models
+
+import resnet_preprocessing
+
+
+DEF_IMAGE_WIDTH = 320
+DEF_IMAGE_HEIGHT = 240
+DEF_WEIGHTS = 'weights.h5'
+
+flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, '')
+flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, '')
+flags.DEFINE_string('weights', DEF_WEIGHTS, 'Weights of the model')
+flags.DEFINE_integer('num_classes', 39, 'Number of classes thei weights were trained for')
+
+FLAGS = flags.FLAGS
+
+FLAGS(sys.argv)
+
+print(FLAGS.weights)
+
+classes = [ "adnetwork", "adobe", "airbnb", "amazon", "applecomputer", "applecomputer_scam", "bancosantander", "bankofamerica", "bnbankru", "bnpparibas", "chase", "craigslist", "dhl", "docusign", "dropbox", "facebook", "genericwebmailphishing", "godaddy", "google", "holding", "ingdirect", "linkedin", "microsoft", "microsoft_scam", "navyfederalcreditunion", "netflix", "orange", "paypal", "phpshell", "posteitaliane", "postmaster", "squarespace", "unicreditgroup", "visa", "vkontakte", "wellsfargo", "wetransfer", "windowslive", "yahoo"]
+
+model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=FLAGS.num_classes)
+
+weights_file = os.path.join(FLAGS.weights)
+model.load_weights(weights_file)
+
+image_bytes = tf.read_file('my_file.png')
+
+images_raw = resnet_preprocessing.preprocess_image(image_bytes, FLAGS.image_width, FLAGS.image_height, resize=True, is_training=True)
+
+images_expanded = tf.expand_dims(images_raw, 0)
+predictions = model.predict_on_batch(images_expanded)
+
+for prediction in predictions:
+ largest_ind = np.argpartition(prediction, -5)[-5:]
+
+ for i in largest_ind[np.argsort(-prediction[largest_ind])]:
+ print(classes[i]+" : \t\t"+str(prediction[i]))