aboutsummaryrefslogtreecommitdiff
path: root/cgan.py
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-02-27 18:53:22 +0000
committerVasil Zlatanov <v@skozl.com>2019-02-27 18:53:22 +0000
commit6b5e772d9b047d8ae44bd120f6b297596e2ce463 (patch)
treee91672b069f11a6d3a7ff6512e6b015a6d6a648e /cgan.py
parenta77a1be3117b8018fdb10be3e85e258d5f5c53b4 (diff)
parent7eec33f756cffee0c4526c7e3e6af71246dd0787 (diff)
downloade4-gan-6b5e772d9b047d8ae44bd120f6b297596e2ce463.tar.gz
e4-gan-6b5e772d9b047d8ae44bd120f6b297596e2ce463.tar.bz2
e4-gan-6b5e772d9b047d8ae44bd120f6b297596e2ce463.zip
Merge branch 'master' of skozl.com:e4-gan
Diffstat (limited to 'cgan.py')
-rw-r--r--cgan.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/cgan.py b/cgan.py
index f71094c..d44f961 100644
--- a/cgan.py
+++ b/cgan.py
@@ -196,13 +196,19 @@ class CGAN():
noise_train = np.random.normal(0, 1, (60000, 100))
noise_test = np.random.normal(0, 1, (10000, 100))
- gen_train = np.zeros(60000).reshape(-1, 1)
- gen_test = np.zeros(10000).reshape(-1, 1)
+ labels_train = np.zeros(60000).reshape(-1, 1)
+ labels_test = np.zeros(10000).reshape(-1, 1)
for i in range(10):
- gen_train[i*600:] = i
- gen_test[i*100:] = i
-
- return self.generator.predict([noise_train, gen_train]), self.generator.predict([noise_test, gen_test]), gen_train, gen_test
+ labels_train[i*600:] = i
+ labels_test[i*100:] = i
+ train_data = self.generator.predict([noise_train, labels_train])
+ test_data = self.generator.predict([noise_test, labels_test])
+ test_data = np.pad(test_data, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ labels_test = labels_test.flatten()
+ labels_train = labels_train.flatten()
+ print(train_data.shape, test_data.shape, labels_test.shape, labes_train.shape)
+ return train_data, test_data, labels_train, labels_test
'''
@@ -211,4 +217,4 @@ if __name__ == '__main__':
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)
-''' \ No newline at end of file
+'''