#!/usr/bin/python
"""Run the logo classifier on one image and print its highest-scoring labels.

Steps:
  1. Build the model with ``models.get_logo_model``.
  2. Load weights from ``--weights`` and re-save them to ``new.hdf5``.
  3. Preprocess ``--image`` with ``resnet_preprocessing.preprocess_image``.
  4. Print the ``--top_k`` best class names with their scores, highest first.
"""
import os
import sys

import numpy as np
import tensorflow as tf
from absl import flags
from absl import logging
from tensorflow.keras.utils import plot_model

import models
import resnet_preprocessing

DEF_IMAGE_WIDTH = None
DEF_IMAGE_HEIGHT = None
DEF_WEIGHTS = 'weights.h5'
DEF_IMAGE = 'my_logo.png'
DEF_TOP_K = 5

flags.DEFINE_integer('image_width', DEF_IMAGE_WIDTH, '')
flags.DEFINE_integer('image_height', DEF_IMAGE_HEIGHT, '')
flags.DEFINE_string('weights', DEF_WEIGHTS, 'Weights of the model')
flags.DEFINE_integer('num_classes', 100,
                     'Number of classes the weights were trained for')
flags.DEFINE_string('image', DEF_IMAGE, 'Path of the image to classify')
flags.DEFINE_integer('top_k', DEF_TOP_K,
                     'Number of highest-scoring classes to print')
flags.DEFINE_bool('augment', False,
                  'Preprocess the input with is_training=True '
                  '(training-time augmentation) instead of inference mode')

FLAGS = flags.FLAGS
FLAGS(sys.argv)

print(FLAGS.weights)

# Output index -> label name. The order must match the label order used at
# training time; the model's i-th output is looked up as classes[i].
# NOTE(review): "logos.txt" looks like a stray directory-listing artifact, but
# it occupies index 59 and the list has exactly 100 entries (the default
# --num_classes). Removing it would shift every later label, so confirm the
# training label order before touching it.
classes = [
    "absa_logo", "adobe_logo", "airbnb_logo", "alibaba_logo", "amazon_logo",
    "americanas_logo", "americanexpress_logo", "aol_logo", "apple_logo",
    "argenta_logo", "att_logo", "bancodechile_logo", "bancodecredito_logo",
    "bancodobrasil_logo", "bancosantander_logo", "bankofamerica_logo",
    "barclaysuk_logo", "bestchangecom_logo", "bet365_logo", "binance_logo",
    "blockchain_logo", "bnpparibas_logo", "bradesco_logo",
    "britishtelecom_logo", "caixabrazil_logo", "canadapharmacy_logo",
    "capitalone_logo", "casasbahia_logo", "chase_logo", "cibc_logo",
    "citicorp_logo", "coinbase_logo", "dailymirror_logo", "dhl_logo",
    "docusign_logo", "dropbox_logo", "ebay_logo", "ethereum_logo",
    "facebook_logo", "federalexpress_logo", "fonbetru_logo", "freefr_logo",
    "genericbankfraud_logo", "godaddy_logo", "google_logo",
    "grandlisboamacau_logo", "halkbankas_logo", "hmrevenuecustoms_logo",
    "hsbceub_logo", "impotsgouvfr_logo", "inggroup_logo", "instagram_logo",
    "interac_logo", "itauunibanco_logo", "lacaixaes_logo",
    "lassurancemaladie_logo", "lasvegassands_logo", "linkedin_logo",
    "lloydsbank_logo", "logos.txt", "luno_logo", "mcafee_logo",
    "mercadolibre_logo", "mercadopago_logo", "metrobank_uk_logo",
    "microsoft_logo", "mostbet_logo", "myetherwallet_logo",
    "nationalaustraliabank_logo", "natwest_logo", "netease_logo",
    "netflix_logo", "netseu_logo", "orange_logo", "ourtime_logo",
    "ovh_logo", "paypal_logo", "phpshell_logo", "posteitaliane_logo",
    "postmaster_logo", "rayban_logo", "rbc_logo", "scotiabank_logo",
    "standardchartered_logo", "steam_logo", "suntrust_logo",
    "swisscom_logo", "torontodominion_logo", "uber_logo", "unicredit_logo",
    "usaa_logo", "usbank_logo", "visa_logo", "vkontakte_logo",
    "walmart_logo", "wellsfargo_logo", "wetransfer_logo", "whatsapp_logo",
    "xfinity_logo", "yahoo_logo",
]


def _top_k_indices(scores, k):
    """Return the indices of the ``k`` largest ``scores``, highest first.

    ``k`` is clamped to the number of scores, so a model with fewer than
    ``k`` outputs does not make ``np.argpartition`` raise.
    """
    scores = np.asarray(scores)
    k = min(k, scores.shape[-1])
    if k <= 0:
        return np.array([], dtype=int)
    # argpartition finds the k largest in O(n); only those k are then sorted.
    largest = np.argpartition(scores, -k)[-k:]
    return largest[np.argsort(-scores[largest])]


def _class_name(index):
    """Map a model output index to its label, or ``class_<index>`` if unnamed."""
    if 0 <= index < len(classes):
        return classes[index]
    return 'class_%d' % index


if FLAGS.num_classes != len(classes):
    logging.warning('--num_classes=%d but %d class names are defined; '
                    'printed labels may not match model outputs.',
                    FLAGS.num_classes, len(classes))

model = models.get_logo_model(width=FLAGS.image_width,
                              height=FLAGS.image_height,
                              num_classes=FLAGS.num_classes)
weights_file = FLAGS.weights
model.load_weights(weights_file)
model.save_weights('new.hdf5')

image_bytes = tf.read_file(FLAGS.image)
# Inference should be deterministic, so training-mode preprocessing is off by
# default. NOTE(review): this assumes is_training=True enables random
# augmentation, as in the reference ResNet preprocessing; confirm against the
# local resnet_preprocessing module.
images_raw = resnet_preprocessing.preprocess_image(
    image_bytes, FLAGS.image_width, FLAGS.image_height,
    resize=False, is_training=FLAGS.augment)

# Uncomment to render the architecture diagram:
# plot_model(model, to_file='model.pdf', show_layer_names=False,
#            show_shapes=True)

# The model expects a batch, so add a leading batch dimension of size 1.
images_expanded = tf.expand_dims(images_raw, 0)
predictions = model.predict_on_batch(images_expanded)

for prediction in predictions:
    for i in _top_k_indices(prediction, FLAGS.top_k):
        print(_class_name(i) + " : \t\t" + str(prediction[i]))