aboutsummaryrefslogtreecommitdiff
path: root/lenet.py
diff options
context:
space:
mode:
Diffstat (limited to 'lenet.py')
-rw-r--r--lenet.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/lenet.py b/lenet.py
index 495deaf..ff042f1 100644
--- a/lenet.py
+++ b/lenet.py
@@ -11,6 +11,7 @@ from tensorflow.keras.metrics import categorical_accuracy
import numpy as np
import random
from sklearn.metrics import accuracy_score
+from sklearn.model_selection import train_test_split
def import_mnist():
from tensorflow.examples.tutorials.mnist import input_data
@@ -126,6 +127,32 @@ def test_classifier(model, x_test, y_true):
print("Test acc:", accuracy_score(y_true, y_pred))
plot_example_errors(y_pred, y_true, x_test)
+def mix_data(split=0, X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen):
+
+ if split == 0:
+ train_data = X_train
+ train_labels = y_train
+ val_data = X_validation
+ val_labels = y_validation
+
+ elif split == 1:
+ train_data = train_gen
+ train_labels = tr_labels_gen
+ val_data = val_gen
+ val_labels = val_labels_gen
+
+ else:
+ X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, stratify=tr_labels_gen)
+ X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, stratify=y_train)
+ X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, stratify=val_labels_gen)
+ X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, stratify=y_validation)
+ train_data = np.concatenate((X_train_gen, X_train_original), axis=0)
+ train_labels = np.concatenate((y_train_gen, y_train_original), axis=0)
+ val_data = np.concatenate((X_validation_gen, X_validation_original), axis=0)
+ val_labels = np.concatenate((y_validation_gen, y_validation_original), axis=0)
+
+ return train_data, train_labels, val_data, val_labels
+
# If file run directly, perform quick test
if __name__ == '__main__':
x_train, y_train, x_val, y_val, x_t, y_t = import_mnist()