diff options
| -rw-r--r-- | cgan.py | 20 | 
1 files changed, 13 insertions, 7 deletions
| @@ -196,13 +196,19 @@ class CGAN():        noise_train = np.random.normal(0, 1, (60000, 100))        noise_test = np.random.normal(0, 1, (10000, 100)) -      gen_train = np.zeros(60000).reshape(-1, 1) -      gen_test = np.zeros(10000).reshape(-1, 1) +      labels_train = np.zeros(60000).reshape(-1, 1) +      labels_test = np.zeros(10000).reshape(-1, 1)        for i in range(10): -        gen_train[i*600:] = i -        gen_test[i*100:] = i - -      return self.generator.predict([noise_train, gen_train]), self.generator.predict([noise_test, gen_test]), gen_train, gen_test +        labels_train[i*600:] = i +        labels_test[i*100:] = i +      train_data = self.generator.predict([noise_train, labels_train]) +      test_data = self.generator.predict([noise_test, labels_test]) +      test_data = np.pad(test_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') +      train_data = np.pad(train_data, ((0,0),(2,2),(2,2),(0,0)), 'constant') +      labels_test = labels_test.flatten() +      labels_train = labels_train.flatten() +      print(train_data.shape, test_data.shape, labels_test.shape, labes_train.shape) +      return train_data, test_data, labels_train, labels_test  ''' @@ -211,4 +217,4 @@ if __name__ == '__main__':      cgan.train(epochs=7000, batch_size=32, sample_interval=200)      train, test, tr_labels, te_labels = cgan.generate_data()      print(train.shape, test.shape) -'''
\ No newline at end of file +''' | 
