aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornunzip <np.scarh@gmail.com>2019-02-27 22:49:16 +0000
committernunzip <np.scarh@gmail.com>2019-02-27 22:49:16 +0000
commit367167680e156ac611c5f1db9f9ff7e66d51a8fe (patch)
treed57ac7e74ff5e07d9a66ffe2238378e8cc3946c2
parentc7c740caff2afc4d615e289ef147c6228cca4a0e (diff)
downloade4-gan-367167680e156ac611c5f1db9f9ff7e66d51a8fe.tar.gz
e4-gan-367167680e156ac611c5f1db9f9ff7e66d51a8fe.tar.bz2
e4-gan-367167680e156ac611c5f1db9f9ff7e66d51a8fe.zip
Reshape labels after predict
-rw-r--r--cgan.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/cgan.py b/cgan.py
index ad3d194..fa05311 100644
--- a/cgan.py
+++ b/cgan.py
@@ -207,13 +207,14 @@ class CGAN():
labels_test[i*1000:] = i
labels_val[i*500:] = i
+ train_data = self.generator.predict([noise_train, labels_train])
+ test_data = self.generator.predict([noise_test, labels_test])
+ val_data = self.generator.predict([noise_val, labels_val])
+
labels_train = keras.utils.to_categorical(labels_train, 10)
labels_test = keras.utils.to_categorical(labels_test, 10)
labels_val = keras.utils.to_categorical(labels_val, 10)
- train_data = self.generator.predict([noise_train, labels_train])
- test_data = self.generator.predict([noise_test, labels_test])
- val_data = self.generator.predict([noise_val, labels_val])
return train_data, test_data, val_data, labels_train, labels_test, labels_val