diff options
Diffstat (limited to 'cgan.py')
-rw-r--r-- | cgan.py | 19 |
1 files changed, 12 insertions, 7 deletions
@@ -196,13 +196,18 @@ class CGAN(): noise_train = np.random.normal(0, 1, (60000, 100)) noise_test = np.random.normal(0, 1, (10000, 100)) - gen_train = np.zeros(60000).reshape(-1, 1) - gen_test = np.zeros(10000).reshape(-1, 1) + labels_train = np.zeros(60000).reshape(-1, 1) + labels_test = np.zeros(10000).reshape(-1, 1) for i in range(10): - gen_train[i*600:] = i - gen_test[i*100:] = i - - return self.generator.predict([noise_train, gen_train]), self.generator.predict([noise_test, gen_test]), gen_train, gen_test + labels_train[i*600:] = i + labels_test[i*100:] = i + train_data = self.generator.predict([noise_train, labels_train]) + test_data = self.generator.predict([noise_test, labels_test]) + test_data = np.pad(test_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') + train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') + labels_test = labels_test.flatten() + labels_train = labels_train.flatten() + return train_data, test_data, labels_train, labels_test ''' @@ -211,4 +216,4 @@ if __name__ == '__main__': cgan.train(epochs=7000, batch_size=32, sample_interval=200) train, test, tr_labels, te_labels = cgan.generate_data() print(train.shape, test.shape) -'''
\ No newline at end of file +''' |