diff options
author | nunzip <np.scarh@gmail.com> | 2019-02-27 23:31:31 +0000 |
---|---|---|
committer | nunzip <np.scarh@gmail.com> | 2019-02-27 23:31:31 +0000 |
commit | 5c14a69666560c1e63d69cf6e758119ae09a47c0 (patch) | |
tree | d39669d0498000eae4ee276e1257a24a11394714 /lenet.py | |
parent | d0f97c01830b1018b1327461c1503bc1cf316eae (diff) | |
download | e4-gan-5c14a69666560c1e63d69cf6e758119ae09a47c0.tar.gz e4-gan-5c14a69666560c1e63d69cf6e758119ae09a47c0.tar.bz2 e4-gan-5c14a69666560c1e63d69cf6e758119ae09a47c0.zip |
Add mix_data function
Diffstat (limited to 'lenet.py')
-rw-r--r-- | lenet.py | 27 |
1 files changed, 27 insertions, 0 deletions
@@ -11,6 +11,7 @@ from tensorflow.keras.metrics import categorical_accuracy import numpy as np import random from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split def import_mnist(): from tensorflow.examples.tutorials.mnist import input_data @@ -126,6 +127,32 @@ def test_classifier(model, x_test, y_true): print("Test acc:", accuracy_score(y_true, y_pred)) plot_example_errors(y_pred, y_true, x_test) +def mix_data(split=0, X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen): + + if split == 0: + train_data = X_train + train_labels = y_train + val_data = X_validation + val_labels = y_validation + + elif split == 1: + train_data = train_gen + train_labels = tr_labels_gen + val_data = val_gen + val_labels = val_labels_gen + + else: + X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, stratify=tr_labels_gen) + X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, stratify=y_train) + X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, stratify=val_labels_gen) + X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, stratify=y_validation) + train_data = np.concatenate((X_train_gen, X_train_original), axis=0) + train_labels = np.concatenate((y_train_gen, y_train_original), axis=0) + val_data = np.concatenate((X_validation_gen, X_validation_original), axis=0) + val_labels = np.concatenate((y_validation_gen, y_validation_original), axis=0) + + return train_data, train_labels, val_data, val_labels + # If file run directly, perform quick test if __name__ == '__main__': x_train, y_train, x_val, y_val, x_t, y_t = import_mnist() |