From a7cb76b4131a7b5b142ac26aa2d47f7e8097c0db Mon Sep 17 00:00:00 2001 From: nunzip Date: Wed, 27 Feb 2019 20:25:24 +0000 Subject: Add validation output --- cgan.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/cgan.py b/cgan.py index 7bef77c..17aa367 100644 --- a/cgan.py +++ b/cgan.py @@ -193,17 +193,23 @@ class CGAN(): plt.close() def generate_data(self): - noise_train = np.random.normal(0, 1, (60000, 100)) + noise_train = np.random.normal(0, 1, (55000, 100)) noise_test = np.random.normal(0, 1, (10000, 100)) + noise_val = np.random.normal(0, 1, (5000, 100)) - labels_train = np.zeros(60000).reshape(-1, 1) + labels_train = np.zeros(55000).reshape(-1, 1) labels_test = np.zeros(10000).reshape(-1, 1) + labels_val = np.zeros(5000).reshape(-1, 1) + for i in range(10): - labels_train[i*600:] = i - labels_test[i*100:] = i + labels_train[i*5500:] = i + labels_test[i*1000:] = i + labels_val[i*500:] = i + train_data = self.generator.predict([noise_train, labels_train]) test_data = self.generator.predict([noise_test, labels_test]) - return train_data, test_data, labels_train, labels_test + val_data = self.generator.predict([noise_train, labels_val]) + return train_data, test_data, val_data, labels_train, labels_test, labels_val ''' -- cgit v1.2.3