From defc939aac1ba77f8cb87b97b1a111ce23d73c52 Mon Sep 17 00:00:00 2001 From: Vasil Zlatanov Date: Thu, 7 Mar 2019 17:08:24 +0000 Subject: Add keras import in cdcgan --- cdcgan.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/cdcgan.py b/cdcgan.py index 8a4b168..7b517ca 100755 --- a/cdcgan.py +++ b/cdcgan.py @@ -2,13 +2,14 @@ from __future__ import print_function, division import tensorflow as keras import tensorflow as tf -from tensorflow.keras.datasets import mnist -from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply -from tensorflow.keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D -from tensorflow.keras.layers import LeakyReLU -from tensorflow.keras.layers import UpSampling2D, Conv2D -from tensorflow.keras.models import Sequential, Model -from tensorflow.keras.optimizers import Adam +import tensorflow.keras as keras +from keras.datasets import mnist +from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply +from keras.layers import BatchNormalization, Embedding, Activation, ZeroPadding2D +from keras.layers import LeakyReLU +from keras.layers import UpSampling2D, Conv2D +from keras.models import Sequential, Model +from keras.optimizers import Adam import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec @@ -243,6 +244,8 @@ class CDCGAN(): return train_data, test_data, val_data, labels_train, labels_test, labels_val +''' if __name__ == '__main__': cdcgan = CDCGAN() cdcgan.train(epochs=4000, batch_size=32) +''' -- cgit v1.2.3