aboutsummaryrefslogtreecommitdiff
path: root/cgan.py
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 18:41:15 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 18:41:15 +0000
commitc5f5d81dc0233c03f339fdf932ef3f72871db3cf (patch)
tree1842bdb02413eac1d11b629297f1e2653eaaaa12 /cgan.py
parent9c26f3b6e6b317c910bf3bdafc9b070c151dff4a (diff)
downloade4-gan-c5f5d81dc0233c03f339fdf932ef3f72871db3cf.tar.gz
e4-gan-c5f5d81dc0233c03f339fdf932ef3f72871db3cf.tar.bz2
e4-gan-c5f5d81dc0233c03f339fdf932ef3f72871db3cf.zip
Reshape and flatten generated data
Diffstat (limited to 'cgan.py')
-rw-r--r--cgan.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/cgan.py b/cgan.py
index f71094c..e0c1e4e 100644
--- a/cgan.py
+++ b/cgan.py
@@ -196,13 +196,18 @@ class CGAN():
noise_train = np.random.normal(0, 1, (60000, 100))
noise_test = np.random.normal(0, 1, (10000, 100))
- gen_train = np.zeros(60000).reshape(-1, 1)
- gen_test = np.zeros(10000).reshape(-1, 1)
+ labels_train = np.zeros(60000).reshape(-1, 1)
+ labels_test = np.zeros(10000).reshape(-1, 1)
for i in range(10):
- gen_train[i*600:] = i
- gen_test[i*100:] = i
-
- return self.generator.predict([noise_train, gen_train]), self.generator.predict([noise_test, gen_test]), gen_train, gen_test
+ labels_train[i*600:] = i
+ labels_test[i*100:] = i
+ train_data = self.generator.predict([noise_train, labels_train])
+ test_data = self.generator.predict([noise_test, labels_test])
+ test_data = np.pad(test_data, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ labels_test = labels_test.flatten()
+ labels_train = labels_train.flatten()
+ return train_data, test_data, labels_train, labels_test
'''
@@ -211,4 +216,4 @@ if __name__ == '__main__':
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)
-''' \ No newline at end of file
+'''