diff options
Diffstat (limited to 'classifier.py')
-rwxr-xr-x | classifier.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/classifier.py b/classifier.py new file mode 100755 index 0000000..b8e2b6a --- /dev/null +++ b/classifier.py @@ -0,0 +1,48 @@ +#!/usr/bin/python + +from absl import flags +from absl import logging + +import os +import sys +import numpy as np +import tensorflow as tf +import models + +import resnet_preprocessing + + +DEF_IMAGE_WIDTH = 320 +DEF_IMAGE_HEIGHT = 240 +DEF_WEIGHTS = 'weights.h5' + +flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, '') +flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, '') +flags.DEFINE_string('weights', DEF_WEIGHTS, 'Weights of the model') +flags.DEFINE_integer('num_classes', 39, 'Number of classes thei weights were trained for') + +FLAGS = flags.FLAGS + +FLAGS(sys.argv) + +print(FLAGS.weights) + +classes = [ "adnetwork", "adobe", "airbnb", "amazon", "applecomputer", "applecomputer_scam", "bancosantander", "bankofamerica", "bnbankru", "bnpparibas", "chase", "craigslist", "dhl", "docusign", "dropbox", "facebook", "genericwebmailphishing", "godaddy", "google", "holding", "ingdirect", "linkedin", "microsoft", "microsoft_scam", "navyfederalcreditunion", "netflix", "orange", "paypal", "phpshell", "posteitaliane", "postmaster", "squarespace", "unicreditgroup", "visa", "vkontakte", "wellsfargo", "wetransfer", "windowslive", "yahoo"] + +model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=FLAGS.num_classes) + +weights_file = os.path.join(FLAGS.weights) +model.load_weights(weights_file) + +image_bytes = tf.read_file('my_file.png') + +images_raw = resnet_preprocessing.preprocess_image(image_bytes, FLAGS.image_width, FLAGS.image_height, resize=True, is_training=True) + +images_expanded = tf.expand_dims(images_raw, 0) +predictions = model.predict_on_batch(images_expanded) + +for prediction in predictions: + largest_ind = np.argpartition(prediction, -5)[-5:] + + for i in largest_ind[np.argsort(-prediction[largest_ind])]: + print(classes[i]+" : \t\t"+str(prediction[i])) |