summaryrefslogtreecommitdiff
path: root/classifier.py
blob: b8e2b6a9ee17297de5047b4d8189ffa17664d06c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/python

from absl import flags
from absl import logging

import os
import sys
import numpy as np
import tensorflow as tf
import models

import resnet_preprocessing


# Default input size handed to the model constructor and to the
# preprocessing step (see the flags below).
DEF_IMAGE_WIDTH = 320
DEF_IMAGE_HEIGHT = 240
# Default path of the weights file loaded into the model.
DEF_WEIGHTS = 'weights.h5'

flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, 'Input image width expected by the model')
flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, 'Input image height expected by the model')
flags.DEFINE_string('weights', DEF_WEIGHTS, 'Weights of the model')
# Must agree with the number of output units the weights were saved with.
flags.DEFINE_integer('num_classes', 39, 'Number of classes the weights were trained for')

FLAGS = flags.FLAGS

# absl flags must be parsed before any flag value is read; this script
# parses the command line at import time instead of going through absl.app.run.
FLAGS(sys.argv)

# Echo which weights file is about to be used.
print(FLAGS.weights)

# Human-readable labels, looked up by the index of each model output score.
# Kept in alphabetical order; the order must match the label order used when
# the weights were trained (presumably alphabetical directory order) — verify.
classes = [
    "adnetwork",
    "adobe",
    "airbnb",
    "amazon",
    "applecomputer",
    "applecomputer_scam",
    "bancosantander",
    "bankofamerica",
    "bnbankru",
    "bnpparibas",
    "chase",
    "craigslist",
    "dhl",
    "docusign",
    "dropbox",
    "facebook",
    "genericwebmailphishing",
    "godaddy",
    "google",
    "holding",
    "ingdirect",
    "linkedin",
    "microsoft",
    "microsoft_scam",
    "navyfederalcreditunion",
    "netflix",
    "orange",
    "paypal",
    "phpshell",
    "posteitaliane",
    "postmaster",
    "squarespace",
    "unicreditgroup",
    "visa",
    "vkontakte",
    "wellsfargo",
    "wetransfer",
    "windowslive",
    "yahoo",
]

# Build the network with the configured input size and class count, then
# load the trained weights into it.
model = models.ResNet50(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=FLAGS.num_classes)

# The path is used as given (os.path.join with a single argument was a no-op).
weights_file = FLAGS.weights
model.load_weights(weights_file)

# NOTE(review): the input path is hard-coded and resolved relative to the
# current working directory.
image_bytes = tf.read_file('my_file.png')

# is_training=False: this is inference, so training-time behaviour
# (presumably random augmentation in resnet_preprocessing) must stay off to
# get deterministic predictions — confirm against preprocess_image.
images_raw = resnet_preprocessing.preprocess_image(image_bytes, FLAGS.image_width, FLAGS.image_height, resize=True, is_training=False)

# Add a leading batch dimension of size 1 for predict_on_batch.
images_expanded = tf.expand_dims(images_raw, 0)
predictions = model.predict_on_batch(images_expanded)

# Number of highest-scoring classes reported per image.
TOP_K = 5

for prediction in predictions:
    # Clamp so np.argpartition does not raise when the model has fewer
    # than TOP_K outputs.
    k = min(TOP_K, len(prediction))
    # argpartition isolates the k largest scores without a full sort;
    # only those k are then sorted, highest score first.
    largest_ind = np.argpartition(prediction, -k)[-k:]

    for i in largest_ind[np.argsort(-prediction[largest_ind])]:
        # Fall back to the raw index when --num_classes exceeds the label list.
        label = classes[i] if i < len(classes) else str(i)
        print(label + " :   \t\t" + str(prediction[i]))