aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cgan.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/cgan.py b/cgan.py
index f71094c..e0c1e4e 100644
--- a/cgan.py
+++ b/cgan.py
@@ -196,13 +196,18 @@ class CGAN():
noise_train = np.random.normal(0, 1, (60000, 100))
noise_test = np.random.normal(0, 1, (10000, 100))
- gen_train = np.zeros(60000).reshape(-1, 1)
- gen_test = np.zeros(10000).reshape(-1, 1)
+ labels_train = np.zeros(60000).reshape(-1, 1)
+ labels_test = np.zeros(10000).reshape(-1, 1)
for i in range(10):
- gen_train[i*600:] = i
- gen_test[i*100:] = i
-
- return self.generator.predict([noise_train, gen_train]), self.generator.predict([noise_test, gen_test]), gen_train, gen_test
+ labels_train[i*600:] = i
+ labels_test[i*100:] = i
+ train_data = self.generator.predict([noise_train, labels_train])
+ test_data = self.generator.predict([noise_test, labels_test])
+ test_data = np.pad(test_data, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant')
+ labels_test = labels_test.flatten()
+ labels_train = labels_train.flatten()
+ return train_data, test_data, labels_train, labels_test
'''
@@ -211,4 +216,4 @@ if __name__ == '__main__':
cgan.train(epochs=7000, batch_size=32, sample_interval=200)
train, test, tr_labels, te_labels = cgan.generate_data()
print(train.shape, test.shape)
-''' \ No newline at end of file
+'''