summaryrefslogtreecommitdiff
path: root/classifier-logo.py
diff options
context:
space:
mode:
Diffstat (limited to 'classifier-logo.py')
-rwxr-xr-xclassifier-logo.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/classifier-logo.py b/classifier-logo.py
new file mode 100755
index 0000000..87e5779
--- /dev/null
+++ b/classifier-logo.py
@@ -0,0 +1,51 @@
+#!/usr/bin/python
+
+from absl import flags
+from absl import logging
+
+import os
+import sys
+import numpy as np
+import tensorflow as tf
+import models
+import resnet_preprocessing
+from tensorflow.keras.utils import plot_model
+
+
+DEF_IMAGE_WIDTH = None
+DEF_IMAGE_HEIGHT = None
+DEF_WEIGHTS = 'weights.h5'
+
+flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, '')
+flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, '')
+flags.DEFINE_string('weights', DEF_WEIGHTS, 'Weights of the model')
+flags.DEFINE_integer('num_classes', 100, 'Number of classes thei weights were trained for')
+
+FLAGS = flags.FLAGS
+
+FLAGS(sys.argv)
+
+print(FLAGS.weights)
+
+classes = [ "absa_logo", "adobe_logo", "airbnb_logo", "alibaba_logo", "amazon_logo", "americanas_logo", "americanexpress_logo", "aol_logo", "apple_logo", "argenta_logo", "att_logo", "bancodechile_logo", "bancodecredito_logo", "bancodobrasil_logo", "bancosantander_logo", "bankofamerica_logo", "barclaysuk_logo", "bestchangecom_logo", "bet365_logo", "binance_logo", "blockchain_logo", "bnpparibas_logo", "bradesco_logo", "britishtelecom_logo", "caixabrazil_logo", "canadapharmacy_logo", "capitalone_logo", "casasbahia_logo", "chase_logo", "cibc_logo", "citicorp_logo", "coinbase_logo", "dailymirror_logo", "dhl_logo", "docusign_logo", "dropbox_logo", "ebay_logo", "ethereum_logo", "facebook_logo", "federalexpress_logo", "fonbetru_logo", "freefr_logo", "genericbankfraud_logo", "godaddy_logo", "google_logo", "grandlisboamacau_logo", "halkbankas_logo", "hmrevenuecustoms_logo", "hsbceub_logo", "impotsgouvfr_logo", "inggroup_logo", "instagram_logo", "interac_logo", "itauunibanco_logo", "lacaixaes_logo", "lassurancemaladie_logo", "lasvegassands_logo", "linkedin_logo", "lloydsbank_logo", "logos.txt", "luno_logo", "mcafee_logo", "mercadolibre_logo", "mercadopago_logo", "metrobank_uk_logo", "microsoft_logo", "mostbet_logo", "myetherwallet_logo", "nationalaustraliabank_logo", "natwest_logo", "netease_logo", "netflix_logo", "netseu_logo", "orange_logo", "ourtime_logo", "ovh_logo", "paypal_logo", "phpshell_logo", "posteitaliane_logo", "postmaster_logo", "rayban_logo", "rbc_logo", "scotiabank_logo", "standardchartered_logo", "steam_logo", "suntrust_logo", "swisscom_logo", "torontodominion_logo", "uber_logo", "unicredit_logo", "usaa_logo", "usbank_logo", "visa_logo", "vkontakte_logo", "walmart_logo", "wellsfargo_logo", "wetransfer_logo", "whatsapp_logo", "xfinity_logo", "yahoo_logo" ]
+
+model = models.get_logo_model(width=FLAGS.image_width, height=FLAGS.image_height, num_classes=FLAGS.num_classes)
+
+weights_file = os.path.join(FLAGS.weights)
+model.load_weights(weights_file)
+model.save_weights('new.hdf5')
+
+image_bytes = tf.read_file('my_logo.png')
+
+images_raw = resnet_preprocessing.preprocess_image(image_bytes, FLAGS.image_width, FLAGS.image_height, resize=False, is_training=True)
+
+#plot_model(model, to_file='model.pdf', show_layer_names=False, show_shapes=True)
+
+images_expanded = tf.expand_dims(images_raw, 0)
+predictions = model.predict_on_batch(images_expanded)
+
+for prediction in predictions:
+ largest_ind = np.argpartition(prediction, -5)[-5:]
+
+ for i in largest_ind[np.argsort(-prediction[largest_ind])]:
+ print(classes[i]+" : \t\t"+str(prediction[i]))