From c5f5d81dc0233c03f339fdf932ef3f72871db3cf Mon Sep 17 00:00:00 2001 From: nunzip Date: Wed, 27 Feb 2019 18:41:15 +0000 Subject: Reshape and flatten generated data --- cgan.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/cgan.py b/cgan.py index f71094c..e0c1e4e 100644 --- a/cgan.py +++ b/cgan.py @@ -196,13 +196,18 @@ class CGAN(): noise_train = np.random.normal(0, 1, (60000, 100)) noise_test = np.random.normal(0, 1, (10000, 100)) - gen_train = np.zeros(60000).reshape(-1, 1) - gen_test = np.zeros(10000).reshape(-1, 1) + labels_train = np.zeros(60000).reshape(-1, 1) + labels_test = np.zeros(10000).reshape(-1, 1) for i in range(10): - gen_train[i*600:] = i - gen_test[i*100:] = i - - return self.generator.predict([noise_train, gen_train]), self.generator.predict([noise_test, gen_test]), gen_train, gen_test + labels_train[i*600:] = i + labels_test[i*100:] = i + train_data = self.generator.predict([noise_train, labels_train]) + test_data = self.generator.predict([noise_test, labels_test]) + test_data = np.pad(test_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') + train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') + labels_test = labels_test.flatten() + labels_train = labels_train.flatten() + return train_data, test_data, labels_train, labels_test ''' @@ -211,4 +216,4 @@ if __name__ == '__main__': cgan.train(epochs=7000, batch_size=32, sample_interval=200) train, test, tr_labels, te_labels = cgan.generate_data() print(train.shape, test.shape) -''' \ No newline at end of file +''' -- cgit v1.2.3-54-g00ecf From 7eec33f756cffee0c4526c7e3e6af71246dd0787 Mon Sep 17 00:00:00 2001 From: nunzip Date: Wed, 27 Feb 2019 18:48:24 +0000 Subject: debug cgan --- cgan.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cgan.py b/cgan.py index e0c1e4e..d44f961 100644 --- a/cgan.py +++ b/cgan.py @@ -207,6 +207,7 @@ class CGAN(): train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') labels_test = labels_test.flatten() labels_train = labels_train.flatten() + print(train_data.shape, test_data.shape, labels_test.shape, labes_train.shape) return train_data, test_data, labels_train, labels_test -- cgit v1.2.3-54-g00ecf