From 0665b7fe0169a1b8126502f8945e4535d3c7c693 Mon Sep 17 00:00:00 2001 From: nunzip Date: Thu, 28 Feb 2019 00:19:40 +0000 Subject: Set state of train_test_split --- lenet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lenet.py b/lenet.py index 31664ee..a38f4e1 100644 --- a/lenet.py +++ b/lenet.py @@ -143,10 +143,10 @@ def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_ val_labels = val_labels_gen else: - X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, stratify=tr_labels_gen) - X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, stratify=y_train) - X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, stratify=val_labels_gen) - X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, stratify=y_validation) + X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, random_state=0, stratify=tr_labels_gen) + X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, random_state=0, stratify=y_train) + X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, random_state=0, stratify=val_labels_gen) + X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, random_state=0, stratify=y_validation) train_data = np.concatenate((X_train_gen, X_train_original), axis=0) train_labels = np.concatenate((y_train_gen, y_train_original), axis=0) val_data = np.concatenate((X_validation_gen, X_validation_original), axis=0) -- cgit v1.2.3