summaryrefslogtreecommitdiff
path: root/classifier-logo.py
blob: 87e577971f03d750450295859e55b115934a66cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/python

from absl import flags
from absl import logging

import os
import sys
import numpy as np
import tensorflow as tf
import models
import resnet_preprocessing
from tensorflow.keras.utils import plot_model


# Defaults for the command-line flags below. The image width/height default to
# None and are passed unchanged to models.get_logo_model() and
# resnet_preprocessing.preprocess_image(); how None is interpreted is up to
# those modules (presumably "no fixed size" — verify there).
DEF_IMAGE_WIDTH = None
DEF_IMAGE_HEIGHT = None
DEF_WEIGHTS = 'weights.h5'

flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH,
                     'Input image width passed to the model and the preprocessor')
flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT,
                     'Input image height passed to the model and the preprocessor')
flags.DEFINE_string('weights', DEF_WEIGHTS, 'Weights of the model')
flags.DEFINE_integer('num_classes', 100,
                     'Number of classes the weights were trained for')

FLAGS = flags.FLAGS

# Parse the command line at module level; this file is executed directly as a
# script (see the shebang), not imported.
FLAGS(sys.argv)

# Echo which weights file is about to be loaded.
print(FLAGS.weights)

# Human-readable name for each model output: classes[i] labels output index i
# when the predictions are printed. Entries are grouped by first letter.
classes = [
    "absa_logo", "adobe_logo", "airbnb_logo", "alibaba_logo",
    "amazon_logo", "americanas_logo", "americanexpress_logo", "aol_logo",
    "apple_logo", "argenta_logo", "att_logo",
    "bancodechile_logo", "bancodecredito_logo", "bancodobrasil_logo",
    "bancosantander_logo", "bankofamerica_logo", "barclaysuk_logo",
    "bestchangecom_logo", "bet365_logo", "binance_logo", "blockchain_logo",
    "bnpparibas_logo", "bradesco_logo", "britishtelecom_logo",
    "caixabrazil_logo", "canadapharmacy_logo", "capitalone_logo",
    "casasbahia_logo", "chase_logo", "cibc_logo", "citicorp_logo",
    "coinbase_logo",
    "dailymirror_logo", "dhl_logo", "docusign_logo", "dropbox_logo",
    "ebay_logo", "ethereum_logo",
    "facebook_logo", "federalexpress_logo", "fonbetru_logo", "freefr_logo",
    "genericbankfraud_logo", "godaddy_logo", "google_logo",
    "grandlisboamacau_logo",
    "halkbankas_logo", "hmrevenuecustoms_logo", "hsbceub_logo",
    "impotsgouvfr_logo", "inggroup_logo", "instagram_logo", "interac_logo",
    "itauunibanco_logo",
    "lacaixaes_logo", "lassurancemaladie_logo", "lasvegassands_logo",
    "linkedin_logo", "lloydsbank_logo",
    # NOTE(review): "logos.txt" looks like a stray file name picked up from a
    # directory listing rather than a brand. It is kept because removing it
    # would shift every later index (and leave 99 names for the default
    # --num_classes of 100); confirm against the training label order.
    "logos.txt",
    "luno_logo",
    "mcafee_logo", "mercadolibre_logo", "mercadopago_logo",
    "metrobank_uk_logo", "microsoft_logo", "mostbet_logo",
    "myetherwallet_logo",
    "nationalaustraliabank_logo", "natwest_logo", "netease_logo",
    "netflix_logo", "netseu_logo",
    "orange_logo", "ourtime_logo", "ovh_logo",
    "paypal_logo", "phpshell_logo", "posteitaliane_logo", "postmaster_logo",
    "rayban_logo", "rbc_logo",
    "scotiabank_logo", "standardchartered_logo", "steam_logo",
    "suntrust_logo", "swisscom_logo",
    "torontodominion_logo",
    "uber_logo", "unicredit_logo", "usaa_logo", "usbank_logo",
    "visa_logo", "vkontakte_logo",
    "walmart_logo", "wellsfargo_logo", "wetransfer_logo", "whatsapp_logo",
    "xfinity_logo",
    "yahoo_logo",
]

def _top_k_indices(scores, k=5):
    """Return the indices of the ``k`` largest values in ``scores``, best first.

    ``k`` is clamped to ``len(scores)`` because ``np.argpartition`` raises
    ValueError when asked for more elements than exist, e.g. for a model
    trained with fewer than ``k`` classes.
    """
    scores = np.asarray(scores)
    k = min(k, len(scores))
    if k <= 0:
        return np.array([], dtype=int)
    # argpartition isolates the top-k without a full sort; only those k
    # candidates are then ordered by descending score.
    candidates = np.argpartition(scores, -k)[-k:]
    return candidates[np.argsort(-scores[candidates])]


def _label_for(index):
    """Return the name in ``classes`` for model output ``index``.

    Falls back to ``"class_<index>"`` when ``index`` is past the end of
    ``classes`` (e.g. --num_classes larger than the list) instead of raising
    IndexError halfway through the output.
    """
    if 0 <= index < len(classes):
        return classes[index]
    return "class_%d" % index


# Printed labels come from ``classes``; if the model's output size differs
# from the number of names, the labels may not correspond to the outputs.
if FLAGS.num_classes != len(classes):
    logging.warning('--num_classes=%d but %d class names are defined; '
                    'printed labels may not match the model outputs.',
                    FLAGS.num_classes, len(classes))

model = models.get_logo_model(width=FLAGS.image_width,
                              height=FLAGS.image_height,
                              num_classes=FLAGS.num_classes)

model.load_weights(FLAGS.weights)
# NOTE(review): this re-saves the just-loaded weights to new.hdf5 in the
# current directory on every run (presumably a leftover format-conversion
# step); confirm it is still wanted.
model.save_weights('new.hdf5')

# tf.read_file does not exist in TensorFlow 2.x; use tf.io.read_file when this
# TensorFlow provides it, otherwise fall back to the 1.x name.
if hasattr(tf, 'io') and hasattr(tf.io, 'read_file'):
    image_bytes = tf.io.read_file('my_logo.png')
else:
    image_bytes = tf.read_file('my_logo.png')

# NOTE(review): is_training=True at inference looks suspicious — if
# resnet_preprocessing applies random training-time augmentation on that
# path, predictions will vary between runs. Confirm in resnet_preprocessing.
images_raw = resnet_preprocessing.preprocess_image(
    image_bytes, FLAGS.image_width, FLAGS.image_height,
    resize=False, is_training=True)

# Uncomment to render the network architecture to model.pdf.
#plot_model(model, to_file='model.pdf', show_layer_names=False, show_shapes=True)

# Add a leading batch axis of size 1 so the single image can be passed to
# predict_on_batch().
images_expanded = tf.expand_dims(images_raw, 0)
predictions = model.predict_on_batch(images_expanded)

# For each image in the batch, print up to five highest-scoring classes,
# best first, in the same "<label> :   \t\t<score>" format as before.
for prediction in predictions:
    for i in _top_k_indices(prediction, 5):
        print(_label_for(i) + " :   \t\t" + str(prediction[i]))