diff options
| -rw-r--r-- | cgan.py | 16 | 
1 files changed, 11 insertions, 5 deletions
| @@ -193,17 +193,23 @@ class CGAN():          plt.close()      def generate_data(self): -      noise_train = np.random.normal(0, 1, (60000, 100)) +      noise_train = np.random.normal(0, 1, (55000, 100))        noise_test = np.random.normal(0, 1, (10000, 100)) +      noise_val = np.random.normal(0, 1, (5000, 100)) -      labels_train = np.zeros(60000).reshape(-1, 1) +      labels_train = np.zeros(55000).reshape(-1, 1)        labels_test = np.zeros(10000).reshape(-1, 1) +      labels_val = np.zeros(5000).reshape(-1, 1) +        for i in range(10): -        labels_train[i*600:] = i -        labels_test[i*100:] = i +        labels_train[i*5500:] = i +        labels_test[i*1000:] = i +        labels_val[i*500:] = i +        train_data = self.generator.predict([noise_train, labels_train])        test_data = self.generator.predict([noise_test, labels_test]) -      return train_data, test_data, labels_train, labels_test +      val_data = self.generator.predict([noise_val, labels_val]) +      return train_data, test_data, val_data, labels_train, labels_test, labels_val  ''' | 
