aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xncdcgan.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/ncdcgan.py b/ncdcgan.py
index ccb99d3..97b137b 100755
--- a/ncdcgan.py
+++ b/ncdcgan.py
@@ -234,19 +234,21 @@ class nCDCGAN():
fig.savefig("images/%d.png" % epoch)
plt.close()
- def generate_data(self):
- noise_train = np.random.normal(0, 1, (55000, 100))
- noise_test = np.random.normal(0, 1, (10000, 100))
- noise_val = np.random.normal(0, 1, (5000, 100))
+ def generate_data(self, out=55000):
+ v_out = int(out/11)
+ te_out = v_out*2
+ noise_train = np.random.normal(0, 1, (out, 100))
+ noise_test = np.random.normal(0, 1, (te_out, 100))
+ noise_val = np.random.normal(0, 1, (v_out, 100))
- labels_train = np.zeros(55000).reshape(-1, 1)
- labels_test = np.zeros(10000).reshape(-1, 1)
- labels_val = np.zeros(5000).reshape(-1, 1)
+ labels_train = np.zeros(out).reshape(-1, 1)
+ labels_test = np.zeros(te_out).reshape(-1, 1)
+ labels_val = np.zeros(v_out).reshape(-1, 1)
for i in range(10):
- labels_train[i*5500:-1] = i
- labels_test[i*1000:-1] = i
- labels_val[i*500:-1] = i
+ labels_train[i*int(out/10):-1] = i
+ labels_test[i*int(te_out/10):-1] = i
+ labels_val[i*int(v_out/10):-1] = i
train_data = self.generator.predict([noise_train, labels_train])
test_data = self.generator.predict([noise_test, labels_test])