from __future__ import print_function import tensorflow.keras as keras import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D from tensorflow.keras import backend as K from tensorflow.keras import optimizers import matplotlib.pyplot as plt from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random from sklearn.metrics import accuracy_score from sklearn.model_selection import train_test_split from classifier_metrics_impl import classifier_score_from_logits def import_mnist(): from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", reshape=False) X_train, y_train = mnist.train.images, mnist.train.labels X_validation, y_validation = mnist.validation.images, mnist.validation.labels X_test, y_test = mnist.test.images, mnist.test.labels y_train = keras.utils.to_categorical(y_train, 10) y_validation = keras.utils.to_categorical(y_validation, 10) y_test = keras.utils.to_categorical(y_test, 10) return X_train, y_train, X_validation, y_validation, X_test, y_test def plot_images(images, cls_true, cls_pred=None): assert len(images) == len(cls_true) == 9 img_shape = (32, 32) # Create figure with 3x3 sub-plots. fig, axes = plt.subplots(3, 3) fig.subplots_adjust(hspace=0.3, wspace=0.3) for i, ax in enumerate(axes.flat): # Plot image. ax.imshow(images[i].reshape(img_shape), cmap='binary') # Show true and predicted classes. if cls_pred is None: xlabel = "True: {0}".format(cls_true[i]) else: xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i]) ax.set_xlabel(xlabel) ax.set_xticks([]) ax.set_yticks([]) plt.show() def plot_example_errors(y_pred, y_true, X_test): correct_prediction = np.equal(y_pred, y_true) incorrect = np.equal(correct_prediction, False) images = X_test[incorrect] cls_pred = y_pred[incorrect] cls_true = y_true[incorrect] plot_images(images=images[0:9], cls_true=cls_true[0:9], cls_pred=cls_pred[0:9].astype(np.int)) def get_lenet(shape): model = keras.Sequential() model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape)) model.add(AveragePooling2D()) model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) model.add(AveragePooling2D()) model.add(Flatten()) model.add(Dense(units=120, activation='relu')) model.add(Dense(units=84, activation='relu')) #model.add(Dense(units=10, activation = 'softmax')) model.add(Dense(units=10, activation = 'relu')) return model def get_lenet_icp(shape): model = keras.Sequential() model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(32,32,1))) model.add(AveragePooling2D()) model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) model.add(AveragePooling2D()) model.add(Flatten()) model.add(Dense(units=120, activation='relu')) model.add(Dense(units=84, activation='relu')) model.add(Dense(units=10, activation = 'relu')) return model def plot_history(history, metric = None): # Plots the loss history of training and validation (if existing) # and a given metric if metric != None: fig, axes = plt.subplots(2,1) axes[0].plot(history.history[metric]) try: axes[0].plot(history.history['val_'+metric]) axes[0].legend(['Train', 'Val']) except: pass axes[0].set_title('{:s}'.format(metric)) axes[0].set_ylabel('{:s}'.format(metric)) axes[0].set_xlabel('Epoch') fig.subplots_adjust(hspace=0.5) axes[1].plot(history.history['loss']) try: axes[1].plot(history.history['val_loss']) axes[1].legend(['Train', 'Val']) except: pass axes[1].set_title('Model Loss') axes[1].set_ylabel('Loss') axes[1].set_xlabel('Epoch') else: plt.plot(history.history['loss']) try: plt.plot(history.history['val_loss']) plt.legend(['Train', 'Val']) except: pass plt.title('Model Loss') plt.ylabel('Loss') plt.xlabel('Epoch') def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=[categorical_accuracy], optimizer = None, keep_training = False, verbose=1): shape = (32, 32, 1) # Pad data to 32x32 (MNIST is 28x28) x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant') model = get_lenet(shape) if optimizer == None: optimizer = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', metrics=metrics, optimizer=optimizer) if keep_training: model.load_weights('./weights.h5') history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=verbose, validation_data = (x_val, y_val)) model.save_weights('./model_gan.h5') plot_history(history, 'categorical_accuracy') plot_history(history) model.save_weights('./weights.h5') return model def test_classifier(model, x_test, y_true): x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') y_pred = model.predict(x_test) logits = tf.convert_to_tensor(y_pred, dtype=tf.float32) inception_score = tf.keras.backend.eval(classifier_score_from_logits(logits)) y_pred = np.argmax(y_pred, axis=1) y_true = np.argmax(y_true, axis=1) plot_example_errors(y_pred, y_true, x_test) return accuracy_score(y_true, y_pred), inception_score def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0): if split == 0: train_data = X_train train_labels = y_train val_data = X_validation val_labels = y_validation elif split == 1: train_data = train_gen train_labels = tr_labels_gen val_data = val_gen val_labels = val_labels_gen else: X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, random_state=0, stratify=tr_labels_gen) X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, random_state=0, stratify=y_train) X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, random_state=0, stratify=val_labels_gen) X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, random_state=0, stratify=y_validation) train_data = np.concatenate((X_train_gen, X_train_original), axis=0) train_labels = np.concatenate((y_train_gen, y_train_original), axis=0) val_data = np.concatenate((X_validation_gen, X_validation_original), axis=0) val_labels = np.concatenate((y_validation_gen, y_validation_original), axis=0) return train_data, train_labels, val_data, val_labels # If file run directly, perform quick test if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() print(y_t.shape) model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3) print(test_classifier(model, x_t, y_t))