aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 20:25:24 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 20:25:24 +0000
commita7cb76b4131a7b5b142ac26aa2d47f7e8097c0db (patch)
tree82fc4967f3157a60bb0c3baa1467166556a84ac6
parentf5e7e167119dff9e2bae122f44fc2172b0cae14b (diff)
downloade4-gan-a7cb76b4131a7b5b142ac26aa2d47f7e8097c0db.tar.gz
e4-gan-a7cb76b4131a7b5b142ac26aa2d47f7e8097c0db.tar.bz2
e4-gan-a7cb76b4131a7b5b142ac26aa2d47f7e8097c0db.zip
Add validation output
-rw-r--r--cgan.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/cgan.py b/cgan.py
index 7bef77c..17aa367 100644
--- a/cgan.py
+++ b/cgan.py
@@ -193,17 +193,23 @@ class CGAN():
plt.close()
def generate_data(self):
- noise_train = np.random.normal(0, 1, (60000, 100))
+ noise_train = np.random.normal(0, 1, (55000, 100))
noise_test = np.random.normal(0, 1, (10000, 100))
+ noise_val = np.random.normal(0, 1, (5000, 100))
- labels_train = np.zeros(60000).reshape(-1, 1)
+ labels_train = np.zeros(55000).reshape(-1, 1)
labels_test = np.zeros(10000).reshape(-1, 1)
+ labels_val = np.zeros(5000).reshape(-1, 1)
+
for i in range(10):
- labels_train[i*600:] = i
- labels_test[i*100:] = i
+ labels_train[i*5500:] = i
+ labels_test[i*1000:] = i
+ labels_val[i*500:] = i
+
train_data = self.generator.predict([noise_train, labels_train])
test_data = self.generator.predict([noise_test, labels_test])
- return train_data, test_data, labels_train, labels_test
+ val_data = self.generator.predict([noise_train, labels_val])
+ return train_data, test_data, val_data, labels_train, labels_test, labels_val
'''