diff options
Diffstat (limited to 'ncdcgan.py')
| -rwxr-xr-x | ncdcgan.py | 22 | 
1 files changed, 12 insertions, 10 deletions
@@ -234,19 +234,21 @@ class nCDCGAN():          fig.savefig("images/%d.png" % epoch)          plt.close() -    def generate_data(self): -      noise_train = np.random.normal(0, 1, (55000, 100)) -      noise_test = np.random.normal(0, 1, (10000, 100)) -      noise_val = np.random.normal(0, 1, (5000, 100)) +    def generate_data(self, out=55000): +      v_out = int(out/11) +      te_out = v_out*2 +      noise_train = np.random.normal(0, 1, (out, 100)) +      noise_test = np.random.normal(0, 1, (te_out, 100)) +      noise_val = np.random.normal(0, 1, (v_out, 100)) -      labels_train = np.zeros(55000).reshape(-1, 1) -      labels_test = np.zeros(10000).reshape(-1, 1) -      labels_val = np.zeros(5000).reshape(-1, 1) +      labels_train = np.zeros(out).reshape(-1, 1) +      labels_test = np.zeros(te_out).reshape(-1, 1) +      labels_val = np.zeros(v_out).reshape(-1, 1)        for i in range(10): -        labels_train[i*5500:-1] = i -        labels_test[i*1000:-1] = i -        labels_val[i*500:-1] = i +        labels_train[i*int(out/10):-1] = i +        labels_test[i*int(te_out/10):-1] = i +        labels_val[i*int(v_out/10):-1] = i        train_data = self.generator.predict([noise_train, labels_train])        test_data = self.generator.predict([noise_test, labels_test])  | 
