from __future__ import print_function import tensorflow.keras as keras import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D from tensorflow.keras import backend as K from tensorflow.keras import optimizers import matplotlib.pyplot as plt from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random from sklearn.metrics import accuracy_score from sklearn.metrics import confusion_matrix from sklearn.model_selection import train_test_split from sklearn.decomposition import PCA from classifier_metrics_impl import classifier_score_from_logits from sklearn.utils import shuffle from sklearn.manifold import TSNE def import_mnist(): from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", reshape=False) X_train, y_train = mnist.train.images, mnist.train.labels X_validation, y_validation = mnist.validation.images, mnist.validation.labels X_test, y_test = mnist.test.images, mnist.test.labels y_train = keras.utils.to_categorical(y_train, 10) y_validation = keras.utils.to_categorical(y_validation, 10) y_test = keras.utils.to_categorical(y_test, 10) return X_train, y_train, X_validation, y_validation, X_test, y_test def plot_images(images, cls_true, cls_pred=None): assert len(images) == len(cls_true) == 9 img_shape = (32, 32) # Create figure with 3x3 sub-plots. fig, axes = plt.subplots(3, 3) fig.subplots_adjust(hspace=0.3, wspace=0.3) for i, ax in enumerate(axes.flat): # Plot image. ax.imshow(images[i].reshape(img_shape), cmap='binary') # Show true and predicted classes. if cls_pred is None: xlabel = "True: {0}".format(cls_true[i]) else: xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i]) ax.set_xlabel(xlabel) ax.set_xticks([]) ax.set_yticks([]) plt.show() def plot_example_errors(y_pred, y_true, X_test): correct_prediction = np.equal(y_pred, y_true) incorrect = np.equal(correct_prediction, False) images = X_test[incorrect] cls_pred = y_pred[incorrect] cls_true = y_true[incorrect] plot_images(images=images[0:9], cls_true=cls_true[0:9], cls_pred=cls_pred[0:9].astype(np.int)) def get_lenet(shape): model = keras.Sequential() model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=shape)) model.add(AveragePooling2D()) model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) model.add(AveragePooling2D()) model.add(Flatten()) model.add(Dense(units=120, activation='relu')) model.add(Dense(units=84, activation='relu')) model.add(Dense(units=10, activation = 'softmax')) return model def get_lenet_icp(shape): model = keras.Sequential() model.add(Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(32,32,1))) model.add(AveragePooling2D()) model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu')) model.add(AveragePooling2D()) model.add(Flatten()) model.add(Dense(units=120, activation='relu')) model.add(Dense(units=84, activation='relu')) model.add(Dense(units=10, activation = 'relu')) return model def plot_history(history, metric = None): # Plots the loss history of training and validation (if existing) # and a given metric if metric != None: fig, axes = plt.subplots(2,1) axes[0].plot(history.history[metric]) try: axes[0].plot(history.history['val_'+metric]) axes[0].legend(['Train', 'Val']) except: pass axes[0].set_title('{:s}'.format(metric)) axes[0].set_ylabel('{:s}'.format(metric)) axes[0].set_xlabel('Epoch') fig.subplots_adjust(hspace=0.5) axes[1].plot(history.history['loss']) try: axes[1].plot(history.history['val_loss']) axes[1].legend(['Train', 'Val']) except: pass axes[1].set_title('Model Loss') axes[1].set_ylabel('Loss') axes[1].set_xlabel('Epoch') else: plt.plot(history.history['loss']) try: plt.plot(history.history['val_loss']) plt.legend(['Train', 'Val']) except: pass plt.title('Model Loss') plt.ylabel('Loss') plt.xlabel('Epoch') def train_classifier(x_train, y_train, x_val, y_val, batch_size=128, epochs=100, metrics=[categorical_accuracy], optimizer = None, keep_training = False, verbose=1): shape = (32, 32, 1) # Pad data to 32x32 (MNIST is 28x28) x_train = np.pad(x_train, ((0,0),(2,2),(2,2),(0,0)), 'constant') x_val = np.pad(x_val, ((0,0),(2,2),(2,2),(0,0)), 'constant') model = get_lenet(shape) if optimizer == None: optimizer = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', metrics=metrics, optimizer=optimizer) if keep_training: model.load_weights('./weights.h5') history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=verbose, validation_data = (x_val, y_val)) model.save_weights('./model_gan.h5') plot_history(history, 'categorical_accuracy') plot_history(history) model.save_weights('./weights.h5') return model def test_classifier(model, x_test, y_true, conf_mat=False, pca=False, tsne=False): x_test = np.pad(x_test, ((0,0),(2,2),(2,2),(0,0)), 'constant') logits = model.predict(x_test) tf_logits = tf.convert_to_tensor(logits, dtype=tf.float32) inception_score = tf.keras.backend.eval(classifier_score_from_logits(tf_logits)) y_pred = np.argmax(logits, axis=1) y_true = np.argmax(y_true, axis=1) plot_example_errors(y_pred, y_true, x_test) cm = confusion_matrix(y_true, y_pred) if conf_mat: plt.matshow(cm, cmap='Blues') plt.colorbar() plt.ylabel('Actual') plt.xlabel('Predicted') plt.show() if pca: set_pca = PCA(n_components=2) pca_rep = set_pca.fit_transform(logits) pca_rep, y_tmp = shuffle(pca_rep, y_true, random_state=0) plt.scatter(pca_rep[:1000, 0], pca_rep[:1000, 1], c=y_true[:1000], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Paired', 10)) plt.xlabel('component 1') plt.ylabel('component 2') plt.colorbar(); plt.show() if tsne: tsne = TSNE(n_components=2, random_state=0) components = tsne.fit_transform(logits) print(components.shape) components, y_tmp = shuffle(components, y_true, random_state=0) plt.scatter(components[:1000, 0], components[:1000, 1], c=y_true[:1000], edgecolor='none', alpha=0.5, cmap=plt.cm.get_cmap('Paired', 10)) plt.xlabel('component 1') plt.ylabel('component 2') plt.colorbar(); plt.show() return accuracy_score(y_true, y_pred), inception_score def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0): if split == 0: train_data = X_train train_labels = y_train val_data = X_validation val_labels = y_validation elif split == 1: train_data = train_gen train_labels = tr_labels_gen val_data = val_gen val_labels = val_labels_gen else: X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, random_state=0, stratify=tr_labels_gen) X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, random_state=0, stratify=y_train) X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, random_state=0, stratify=val_labels_gen) X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, random_state=0, stratify=y_validation) train_data = np.concatenate((X_train_gen, X_train_original), axis=0) train_labels = np.concatenate((y_train_gen, y_train_original), axis=0) val_data = np.concatenate((X_validation_gen, X_validation_original), axis=0) val_labels = np.concatenate((y_validation_gen, y_validation_original), axis=0) return train_data, train_labels, val_data, val_labels # If file run directly, perform quick test if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() print(y_t.shape) model = train_classifier(x_train[:100], y_train[:100], x_val, y_val, epochs=3) print(test_classifier(model, x_t, y_t, pca=False, tsne=True))