aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-03-10 18:15:14 +0000
committernunzip <np.scarh@gmail.com>2019-03-10 18:15:14 +0000
commit4f96cbcc673f46d8fdd5bc7bc253aaeb2aa2af88 (patch)
tree29e2a0d36f5eaa9283f9eaf9e7564f15650ad889
parent288d7ff75a1acd017ba3a498f508292d5328bdd3 (diff)
downloade4-gan-4f96cbcc673f46d8fdd5bc7bc253aaeb2aa2af88.tar.gz
e4-gan-4f96cbcc673f46d8fdd5bc7bc253aaeb2aa2af88.tar.bz2
e4-gan-4f96cbcc673f46d8fdd5bc7bc253aaeb2aa2af88.zip
Rewrite generate data
-rwxr-xr-xcgan.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/cgan.py b/cgan.py
index 0e3dacd..a34a0e3 100755
--- a/cgan.py
+++ b/cgan.py
@@ -194,21 +194,23 @@ class CGAN():
fig.savefig("images/%d.png" % epoch)
plt.close()
- def generate_data(self, split):
- train_size = int((55000*100/split)-55000)
- val_size = int(train_size/11)
+ def generate_data(self, output_train = 55000):
+ # with this output_train you specify how much training data you want. the other two variables produce validation
+ # and testing data in proportions equal to the ones of MNIST dataset
+
+ val_size = int(output_train/11)
test_size = 2*val_size
- noise_train = np.random.normal(0, 1, (train_size, 100))
+ noise_train = np.random.normal(0, 1, (output_train, 100))
noise_test = np.random.normal(0, 1, (test_size, 100))
noise_val = np.random.normal(0, 1, (val_size, 100))
- labels_train = np.zeros(train_size).reshape(-1, 1)
+ labels_train = np.zeros(output_train).reshape(-1, 1)
labels_test = np.zeros(test_size).reshape(-1, 1)
labels_val = np.zeros(val_size).reshape(-1, 1)
for i in range(10):
- labels_train[i*int(train_size/10):-1] = i
+ labels_train[i*int(output_train/10):-1] = i
labels_test[i*int(test_size/10):-1] = i
labels_val[i*int(val_size/10):-1] = i