aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-09 16:22:31 +0000
committernunzip <np.scarh@gmail.com>2019-03-09 16:22:31 +0000
commit791ff2ad0e32bf1c916f6d4119f0e5eb1ab9a35c (patch)
tree073b834729d9e70c3156632d6ad432b7f6014ad6
parent3adb475617e8dd8e53335e834083e6c5348110a5 (diff)
downloade4-gan-791ff2ad0e32bf1c916f6d4119f0e5eb1ab9a35c.tar.gz
e4-gan-791ff2ad0e32bf1c916f6d4119f0e5eb1ab9a35c.tar.bz2
e4-gan-791ff2ad0e32bf1c916f6d4119f0e5eb1ab9a35c.zip
Produce variable output gen
-rwxr-xr-xcgan.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/cgan.py b/cgan.py
index d27b11b..0e3dacd 100755
--- a/cgan.py
+++ b/cgan.py
@@ -194,19 +194,23 @@ class CGAN():
fig.savefig("images/%d.png" % epoch)
plt.close()
- def generate_data(self):
- noise_train = np.random.normal(0, 1, (55000, 100))
- noise_test = np.random.normal(0, 1, (10000, 100))
- noise_val = np.random.normal(0, 1, (5000, 100))
-
- labels_train = np.zeros(55000).reshape(-1, 1)
- labels_test = np.zeros(10000).reshape(-1, 1)
- labels_val = np.zeros(5000).reshape(-1, 1)
-
+ def generate_data(self, split):
+ train_size = int((55000*100/split)-55000)
+ val_size = int(train_size/11)
+ test_size = 2*val_size
+
+ noise_train = np.random.normal(0, 1, (train_size, 100))
+ noise_test = np.random.normal(0, 1, (test_size, 100))
+ noise_val = np.random.normal(0, 1, (val_size, 100))
+
+ labels_train = np.zeros(train_size).reshape(-1, 1)
+ labels_test = np.zeros(test_size).reshape(-1, 1)
+ labels_val = np.zeros(val_size).reshape(-1, 1)
+
for i in range(10):
- labels_train[i*5500:-1] = i
- labels_test[i*1000:-1] = i
- labels_val[i*500:-1] = i
+ labels_train[i*int(train_size/10):-1] = i
+ labels_test[i*int(test_size/10):-1] = i
+ labels_val[i*int(val_size/10):-1] = i
train_data = self.generator.predict([noise_train, labels_train])
test_data = self.generator.predict([noise_test, labels_test])