aboutsummaryrefslogtreecommitdiff
path: root/lenet.py
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-28 00:19:40 +0000
committernunzip <np.scarh@gmail.com>2019-02-28 00:19:40 +0000
commit0665b7fe0169a1b8126502f8945e4535d3c7c693 (patch)
tree342b59a2fdc173bc38a9739795c6f456e17d2060 /lenet.py
parentcbb551537a2505d8f189a4faf9e8c67fe1753d47 (diff)
downloade4-gan-0665b7fe0169a1b8126502f8945e4535d3c7c693.tar.gz
e4-gan-0665b7fe0169a1b8126502f8945e4535d3c7c693.tar.bz2
e4-gan-0665b7fe0169a1b8126502f8945e4535d3c7c693.zip
Set state of train_test_split
Diffstat (limited to 'lenet.py')
-rw-r--r--lenet.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/lenet.py b/lenet.py
index 31664ee..a38f4e1 100644
--- a/lenet.py
+++ b/lenet.py
@@ -143,10 +143,10 @@ def mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_
val_labels = val_labels_gen
else:
- X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, stratify=tr_labels_gen)
- X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, stratify=y_train)
- X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, stratify=val_labels_gen)
- X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, stratify=y_validation)
+ X_train_gen, _, y_train_gen, _ = train_test_split(train_gen, tr_labels_gen, test_size=1-split, random_state=0, stratify=tr_labels_gen)
+ X_train_original, _, y_train_original, _ = train_test_split(X_train, y_train, test_size=split, random_state=0, stratify=y_train)
+ X_validation_gen, _, y_validation_gen, _ = train_test_split(val_gen, val_labels_gen, test_size=1-split, random_state=0, stratify=val_labels_gen)
+ X_validation_original, _, y_validation_original, _ = train_test_split(X_validation, y_validation, test_size=split, random_state=0, stratify=y_validation)
train_data = np.concatenate((X_train_gen, X_train_original), axis=0)
train_labels = np.concatenate((y_train_gen, y_train_original), axis=0)
val_data = np.concatenate((X_validation_gen, X_validation_original), axis=0)