{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook.\n# NOTE(review): no cell below plots with matplotlib (images are sent to\n# pynet's Board), so this magic is probably unnecessary.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nGeneration of 3D brain MRI using VAE Generative Adversarial Networks\n====================================================================\n\nCredit: A Grigis\n\nBased on:\n\n- https://github.com/cyclomon/3dbraingen\n\nThis tutorial builds intuition for a simple Generative Adversarial Network\n(GAN) that generates realistic MRI images. Here, we use a model that can\ngenerate 3D brain MRI data from random vectors by learning the data\ndistribution.\nAfter reading this tutorial, you will understand the technical details needed\nto implement a VAE-GAN.\n\nLet's begin by importing the required packages:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\n# Stop here when the CI_MODE environment variable is set.\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport logging\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\n\nfrom pynet.datasets import DataManager, fetch_brats\nfrom pynet.interfaces import DeepLearningInterface\nfrom pynet.plotting import Board, update_board\nfrom pynet.utils import setup_logging\nfrom pynet.preprocessing.spatial import downsample\nfrom pynet.models import BGDiscriminator, BGEncoder, BGGenerator\n\n\n# Global parameters\nlogger = logging.getLogger(\"pynet\")\nsetup_logging(level=\"debug\")\n\n# Seed the random number generators so that the stochastic steps (weight\n# initialization, random latent codes and, presumably, the DataManager\n# sampling) are reproducible from one run to the next.\nseed = 42\ntorch.manual_seed(seed)\nnp.random.seed(seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model will be trained on the BraTS dataset.\n\nWe will train the model to synthesize MRI data of a brain disorder (glioma).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Set the BRATS_DIR environment variable to use another copy of the dataset.\ndata = fetch_brats(\n    datasetdir=os.environ.get(\n        \"BRATS_DIR\", \"/neurospin/nsap/processed/deepbrain/tumor/data/brats\"))\nbatch_size = 4\n\ndef transformer(data, imgtype=\"flair\"):\n    \"\"\" Keep the requested MRI modalities and downsample them.\n\n    Parameters\n    ----------\n    data: array-like\n        the input channels, indexed on the first axis in the order\n        t1, t1ce, t2, flair.\n    imgtype: str, list/tuple of str or None, default 'flair'\n        the modality(ies) to keep; None keeps channels 0 to 3.\n\n    Returns\n    -------\n    transformed_data: numpy array\n        the kept channels, in input order, each passed through\n        downsample(..., scale=3).\n    \"\"\"\n    typemap = {\n        \"t1\": 0, \"t1ce\": 1, \"t2\": 2, \"flair\": 3}\n    if imgtype is None:\n        keep = range(4)\n    else:\n        # Accept a single modality name or any iterable of names.\n        if isinstance(imgtype, str):\n            imgtype = [imgtype]\n        keep = [typemap[key] for key in imgtype]\n    transformed_data = [\n        downsample(data[channel_id], scale=3)\n        for channel_id in range(len(data)) if channel_id in keep]\n    return np.asarray(transformed_data)\n\nmanager = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    stratify_label=\"grade\",\n    number_of_folds=10,\n    batch_size=batch_size,\n    test_size=0,\n    input_transforms=[transformer],\n    sample_size=0.2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Loss\n----\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Binary cross-entropy between the discriminator outputs and the real (1) /\n# fake (0) targets used in the training loop. The reconstruction term is a\n# squared error computed directly in the loop, so no L1 criterion is needed.\n# NOTE(review): BCELoss expects probabilities in [0, 1]; confirm that\n# BGDiscriminator(with_logit=True) returns probabilities, not raw logits.\ncriterion_bce = nn.BCELoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training\n--------\n\nWe'll train the encoder, generator and discriminator to optimize their losses\nusing the Adam optimizer.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_epochs = 100\nlatent_dim = 1000\nuse_cuda = False\nchannels = 1\nin_shape = (50, 64, 45) # (150, 190, 135)\ngamma = 20\nbeta = 10\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\ngenerator = BGGenerator(\n    in_shape=in_shape, out_channels=channels, start_filts=64,\n    latent_dim=latent_dim, mode=\"trilinear\", with_code=False).to(device)\ndiscriminator = BGDiscriminator(\n    in_shape=in_shape, in_channels=channels, out_channels=channels,\n    start_filts=64, with_logit=True).to(device)\nencoder = BGEncoder(\n    in_shape=in_shape, in_channels=channels, start_filts=64,\n    latent_dim=latent_dim).to(device)\ng_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001)\nd_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0001)\ne_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.0001)\n# Discriminator targets (1 = real, 0 = fake), sliced to the current batch\n# size inside the loop.\nreal_y = torch.ones((batch_size, channels), device=device)\nfake_y = torch.zeros((batch_size, channels), device=device)\nboard = Board(port=8097, host=\"http://localhost\", env=\"vae\")\noutdir = \"/tmp/vae-gan/checkpoint\"\nos.makedirs(outdir, exist_ok=True)\n\nfor epoch in range(n_epochs):\n    loaders = manager.get_dataloader(train=True, validation=False,\n                                     fold_index=0)\n    for iteration, item in enumerate(loaders.train):\n        real_images = item.inputs.to(device)\n        # Use a local name: overwriting the global `batch_size` would make a\n        # re-run of this cell size `real_y`/`fake_y` to the last (possibly\n        # smaller) batch and break on the next full batch.\n        n_samples = real_images.size(0)\n        real_label = real_y[:n_samples]\n        fake_label = fake_y[:n_samples]\n        z_rand = torch.randn((n_samples, latent_dim), device=device)\n        mean, logvar, code = encoder(real_images)\n        x_rec = generator(code)\n        x_rand = generator(z_rand)\n        logger.debug(\"X_real: {0}\".format(real_images.shape))\n        logger.debug(\"X_rand: {0}\".format(x_rand.shape))\n        logger.debug(\"X_rec: {0}\".format(x_rec.shape))\n\n        # Train discriminator: real images -> 1, reconstructed/random -> 0\n        d_optimizer.zero_grad()\n        d_real_loss = criterion_bce(discriminator(real_images), real_label)\n        d_recon_loss = criterion_bce(discriminator(x_rec), fake_label)\n        d_fake_loss = criterion_bce(discriminator(x_rand), fake_label)\n        dis_loss = d_recon_loss + d_real_loss + d_fake_loss\n        # Keep the encoder/generator graph: it is backpropagated again below.\n        dis_loss.backward(retain_graph=True)\n        d_optimizer.step()\n\n        # Generator (decoder) loss, evaluated with the updated discriminator\n        g_optimizer.zero_grad()\n        d_real_loss = criterion_bce(discriminator(real_images), real_label)\n        d_recon_loss = criterion_bce(discriminator(x_rec), fake_label)\n        d_fake_loss = criterion_bce(discriminator(x_rand), fake_label)\n        d_img_loss = d_real_loss + d_recon_loss + d_fake_loss\n        gen_img_loss = -d_img_loss\n        rec_loss = ((x_rec - real_images) ** 2).mean()\n        err_dec = gamma * rec_loss + gen_img_loss\n\n        # Encoder loss: KL divergence to N(0, I) plus reconstruction error\n        prior_loss = 1 + logvar - mean.pow(2) - logvar.exp()\n        prior_loss = (-0.5 * torch.sum(prior_loss)) / mean.numel()\n        err_enc = prior_loss + beta * rec_loss\n\n        # The encoder gradients flow back through the generator, so they\n        # must be computed before g_optimizer.step() modifies the generator\n        # weights in place; otherwise autograd raises an in-place\n        # modification error when backpropagating err_enc.\n        encoder_params = list(encoder.parameters())\n        enc_grads = torch.autograd.grad(\n            err_enc, encoder_params, retain_graph=True, allow_unused=True)\n\n        # Train generator\n        err_dec.backward()\n        g_optimizer.step()\n\n        # Train encoder with the err_enc gradients only (this discards what\n        # dis_loss and err_dec accumulated in the encoder parameters)\n        for param, grad in zip(encoder_params, enc_grads):\n            param.grad = grad\n        e_optimizer.step()\n\n        # Visualization\n        if iteration % 4 == 0:\n            print(\"[{0}/{1}]\".format(epoch, n_epochs),\n                  \"D: {:<8.3}\".format(dis_loss.item()),\n                  \"En: {:<8.3}\".format(err_enc.item()),\n                  \"De: {:<8.3}\".format(err_dec.item()))\n            # `volume`, not `data`, so the fetched dataset is not clobbered.\n            for name, volume in [(\"X_real\", real_images), (\"X_dec\", x_rec),\n                                 (\"X_rand\", x_rand)]:\n                featmask = (0.5 * volume[0] + 0.5).detach().cpu().numpy()\n                # Middle slice along the last axis of the first sample.\n                img = featmask[..., featmask.shape[-1] // 2]\n                img = np.expand_dims(img, axis=1)\n                # Guard against a division by zero on blank images.\n                img_max = img.max()\n                if img_max > 0:\n                    img = img / img_max\n                img = img * 255\n                board.viewer.images(\n                    img,\n                    opts={\n                        \"title\": name,\n                        \"caption\": name},\n                    win=name)\n\n    # Save model\n    for name, model in [(\"generator\", generator),\n                        (\"discriminator\", discriminator),\n                        (\"encoder\", encoder)]:\n        fname = os.path.join(\n            outdir, name + \"_epoch_\" + str(epoch + 1) + \".pth\")\n        torch.save(model.state_dict(), fname)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conclusion\n----------\n\nVariational Auto-Encoder (VAE) GANs are largely free from mode collapse, but\ntheir outputs tend to be blurry. To address both the mode collapse of GANs and\nthe blurriness of VAEs, the next tutorial uses \u03b1-GAN, a model that combines\nboth approaches.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}