aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasil Zlatanov <v@skozl.com>2019-03-08 00:52:58 +0000
committerVasil Zlatanov <v@skozl.com>2019-03-08 00:52:58 +0000
commit2008dbfa3bb70543b8d071947e3877b01b730308 (patch)
treea4ea0f637c350c50c8b6a0dca087d5b5818788a9
parent836b0215275cfa52142a66de6978448515b17f4f (diff)
downloade4-gan-2008dbfa3bb70543b8d071947e3877b01b730308.tar.gz
e4-gan-2008dbfa3bb70543b8d071947e3877b01b730308.tar.bz2
e4-gan-2008dbfa3bb70543b8d071947e3877b01b730308.zip
Add ipynb to repo
-rw-r--r--computer_vision.ipynb311
1 files changed, 311 insertions, 0 deletions
diff --git a/computer_vision.ipynb b/computer_vision.ipynb
new file mode 100644
index 0000000..584b19d
--- /dev/null
+++ b/computer_vision.ipynb
@@ -0,0 +1,311 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "computer_vision.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "metadata": {
+ "id": "o8rKg5jPF_aa",
+ "colab_type": "code",
+ "outputId": "9569d1de-a4e6-42b0-ab60-713b627ec02d",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import csv\n",
+ "import numpy as np\n",
+ "\n",
+ "repo_location = os.path.join('/content', 'e4-gan')\n",
+ "print(repo_location)\n",
+ "if not os.path.exists(repo_location):\n",
+ " !git clone https://git.skozl.com/e4-gan /content/e4-gan\n",
+ " \n",
+ "os.chdir(repo_location)\n",
+ "!cd /content/e4-gan\n",
+ "!git pull\n"
+ ],
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "/content/e4-gan\n",
+ "Already up to date.\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Mci7b38-bDjf",
+ "colab_type": "code",
+ "outputId": "0ec49551-a260-4469-f656-146b6a3bb226",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 161
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "imgfolder = os.path.join(repo_location, 'images')\n",
+ "print(imgfolder)\n",
+ "if not os.path.exists(imgfolder):\n",
+ " !mkdir images\n",
+ " print('Make image directory')\n",
+ " \n",
+ "from dcgan import DCGAN\n",
+ "from cgan import CGAN\n",
+ "from cdcgan import CDCGAN\n",
+ "from lenet import *\n",
+ " \n",
+ "#vbn_dcgan = DCGAN(virtual_batch_normalization=True)\n",
+ "#utils = os.path.join('/content', 'utils')\n",
+ "cgan = CGAN()\n",
+ "cdcgan = CDCGAN()\n",
+ "\n",
+ "#dcgan.train(epochs=4000, batch_size=32, save_interval=1000)\n",
+ "#cgan.train(epochs=20000, batch_size=32, sample_interval=1000, graph=True)"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "/content/e4-gan/images\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Using TensorFlow backend.\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Colocations handled automatically by placer.\n",
+ "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "LcifrT3feO6P",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "#cdcgan.discriminator.save_weights('disc_weights.h5')\n",
+ "#cdcgan.generator.save_weights('gen_weights.h5')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "X_-PBBXitdui",
+ "colab_type": "code",
+ "outputId": "b49313cf-54b3-44ee-9afb-3dfe1d16906d",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 125
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "cdcgan.train(epochs=10001, batch_size=128, sample_interval=200, graph=True, smooth_real=0.9)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "\r 0%| | 0/10001 [00:00<?, ?it/s]"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "Use tf.cast instead.\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.6/dist-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
+ " 'Discrepancy between trainable weights and collected trainable'\n",
+ " 7%|▋ | 728/10001 [01:38<19:14, 8.03it/s]"
+ ],
+ "name": "stderr"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "a56uNnvlwZgt",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "#cgan.train(epochs=10000, batch_size=32, sample_interval=1000, graph=True, smooth_real=0.9)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "ZYR97BHmMuQE",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "X_train, y_train, X_validation, y_validation, X_test, y_test = import_mnist()\n",
+ "train_gen, test_gen, val_gen, tr_labels_gen, te_labels_gen, val_labels_gen = cdcgan.generate_data()\n",
+ "\n",
+ "# If split = 0 use only original mnist set\n",
+ "train_data, train_labels, val_data, val_labels = mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=0.3)\n",
+ "print(val_data.shape, val_labels.shape, train_data.shape, train_labels.shape)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "zbrG6Uk8Tfqd",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "cdcgan.generator.save('gen.h5')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Lfd0uuM0m98s",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "model = train_classifier(train_data, train_labels, X_validation, y_validation, batch_size=128, epochs=100)\n",
+ "#For further steps of fine tuning use:\n",
+ "#model.fit(train_data, train_labels, batch_size=128, epochs=100, verbose=1, validation_data = (X_validation, y_validation))\n",
+ "model.save_weights('lenet.h5')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "8e-UgoZ7et9D",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "model = get_lenet_icp((32,32,1))\n",
+ "model.load_weights('lenet.h5')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "CiGcNvjeNOjp",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "accuracy_mnist, inception_mnist = test_classifier(model, X_test, y_test)\n",
+ "print('Accuracy', accuracy_mnist)\n",
+ "print('Inception Score', inception_mnist)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "mWPYOjK3X3cS",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "accuracy_gen, inception_gen = test_classifier(model, test_gen, te_labels_gen)\n",
+ "print('Accuracy', accuracy_gen)\n",
+ "print('Inception Score', inception_gen)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "ZxTMGlwuj9vu",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "'''\n",
+ "import matplotlib.pyplot as plt \n",
+ "\n",
+ "precision_mnist = np.array(11)\n",
+ "inception_score = np.array(11)\n",
+ "\n",
+ "for i in range(11):\n",
+ " split = float(i)/10\n",
+ " train_data, train_labels, val_data, val_labels = mix_data(X_train, y_train, X_validation, y_validation, train_gen, tr_labels_gen, val_gen, val_labels_gen, split=split)\n",
+ " model = train_classifier(train_data, train_labels, X_validation, y_validation, batch_size=128, epochs=100)\n",
+ " precision_mnist[i] = test_classifier(model, X_test, y_test)\n",
+ " inception_score[i] = test_classifier(model, test_gen, te_labels_gen)\n",
+ " \n",
+ "xgrid = 100*np.arange(11)\n",
+ "plt.plot(xgrid, 100*precision_mnist)\n",
+ "plt.plot(xgrid, 100*inception_score)\n",
+ "plt.ylabel('Classification Accuracy (%)')\n",
+ "plt.xlabel('Amount of generated data used for training')\n",
+ "plt.legend(('MNIST Test Set', 'CGAN Generated Test Set'), loc='best')\n",
+ "plt.show()\n",
+ "'''"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+} \ No newline at end of file