From 791ff2ad0e32bf1c916f6d4119f0e5eb1ab9a35c Mon Sep 17 00:00:00 2001 From: nunzip Date: Sat, 9 Mar 2019 16:22:31 +0000 Subject: Produce variable output gen --- cgan.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) (limited to 'cgan.py') diff --git a/cgan.py b/cgan.py index d27b11b..0e3dacd 100755 --- a/cgan.py +++ b/cgan.py @@ -194,19 +194,23 @@ class CGAN(): fig.savefig("images/%d.png" % epoch) plt.close() - def generate_data(self): - noise_train = np.random.normal(0, 1, (55000, 100)) - noise_test = np.random.normal(0, 1, (10000, 100)) - noise_val = np.random.normal(0, 1, (5000, 100)) - - labels_train = np.zeros(55000).reshape(-1, 1) - labels_test = np.zeros(10000).reshape(-1, 1) - labels_val = np.zeros(5000).reshape(-1, 1) - + def generate_data(self, split): + train_size = int((55000*100/split)-55000) + val_size = int(train_size/11) + test_size = 2*val_size + + noise_train = np.random.normal(0, 1, (train_size, 100)) + noise_test = np.random.normal(0, 1, (test_size, 100)) + noise_val = np.random.normal(0, 1, (val_size, 100)) + + labels_train = np.zeros(train_size).reshape(-1, 1) + labels_test = np.zeros(test_size).reshape(-1, 1) + labels_val = np.zeros(val_size).reshape(-1, 1) + for i in range(10): - labels_train[i*5500:-1] = i - labels_test[i*1000:-1] = i - labels_val[i*500:-1] = i + labels_train[i*int(train_size/10):-1] = i + labels_test[i*int(test_size/10):-1] = i + labels_val[i*int(val_size/10):-1] = i train_data = self.generator.predict([noise_train, labels_train]) test_data = self.generator.predict([noise_test, labels_test]) -- cgit v1.2.3-54-g00ecf