{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nGeneration of 3D brain MRI using VAE Generative Adversarial Networks\n====================================================================\n\nCredit: A Grigis\n\nBased on:\n\n- https://github.com/cyclomon/3dbraingen\n\nThis tutorial builds intuition for simple Generative Adversarial Networks\n(GAN) that generate realistic MRI images. Here, we use a model that can\nsuccessfully generate 3D brain MRI data by integrating a code\ndiscriminator.\n\nLet's begin by importing the required packages:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport numpy as np\nimport logging\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom pynet.datasets import DataManager, fetch_brats\nfrom pynet.interfaces import DeepLearningInterface\nfrom pynet.plotting import Board, update_board\nfrom pynet.utils import setup_logging\nfrom pynet.preprocessing.spatial import downsample\nfrom pynet.models import BGDiscriminator, BGGenerator, BGCodeDiscriminator\n\n\n# Global parameters\nlogger = logging.getLogger(\"pynet\")\nsetup_logging(level=\"debug\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model will be trained on the BraTS dataset.\n\nWe will train the model to synthesize MRI data of a brain disorder (glioma).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Location of the BraTS data: it can be overridden with the\n# BRATS_DATASETDIR environment variable instead of editing this cell.\ndatasetdir = os.environ.get(\n    \"BRATS_DATASETDIR\", \"/neurospin/nsap/processed/deepbrain/tumor/data/brats\")\ndata = fetch_brats(datasetdir=datasetdir)\nbatch_size = 4\n\n\ndef transformer(data, imgtype=\"flair\", scale=3):\n    \"\"\" Select some MRI modalities and downsample them.\n\n    Parameters\n    ----------\n    data: array-like\n        the input channels, indexed as t1, t1ce, t2, flair (see typemap).\n    imgtype: str, list/tuple of str or None, default 'flair'\n        the modality (or modalities) to keep, None to keep all four.\n    scale: int, default 3\n        the downsampling factor forwarded to `downsample`.\n\n    Returns\n    -------\n    transformed_data: np.ndarray\n        the selected downsampled channels, stacked along the first axis.\n    \"\"\"\n    typemap = {\n        \"t1\": 0, \"t1ce\": 1, \"t2\": 2, \"flair\": 3}\n    if imgtype is None:\n        selected = set(range(4))\n    else:\n        # Wrap a single modality name; any other iterable of names\n        # (list, tuple, ...) is used as is.\n        if isinstance(imgtype, str):\n            imgtype = [imgtype]\n        selected = {typemap[key] for key in imgtype}\n    # Keep the channel order of the input.\n    transformed_data = [\n        downsample(data[channel_id], scale=scale)\n        for channel_id in range(len(data)) if channel_id in selected]\n    return np.asarray(transformed_data)\n\n\nmanager = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    stratify_label=\"grade\",\n    number_of_folds=10,\n    batch_size=batch_size,\n    test_size=0,\n    input_transforms=[transformer],\n    sample_size=0.2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
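        "Loss\n----\n\nThe discriminator and the code discriminator are regularized with the\nWGAN-GP gradient penalty, and an L1 loss measures how well the generator\nreconstructs the encoded images.\n\n"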
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def calc_gradient_penalty(model, x, x_gen, w=10, eps=1e-15):\n    \"\"\" WGAN-GP gradient penalty.\n\n    Random points are sampled on the straight lines between the real and\n    the generated samples, and the critic is penalized when the L2 norm of\n    its gradient w.r.t. these points deviates from 1.\n\n    Parameters\n    ----------\n    model: callable\n        the critic: its summed output is differentiated w.r.t. its input.\n    x: torch.Tensor\n        the real samples (first dimension is the batch).\n    x_gen: torch.Tensor\n        the generated samples, same size as x.\n    w: float, default 10\n        the weight of the penalty.\n    eps: float, default 1e-15\n        constant added to each squared gradient element before the square\n        root, so that the norm stays differentiable at zero.\n\n    Returns\n    -------\n    penalty: torch.Tensor\n        the scalar penalty w * mean((||grad||_2 - 1) ** 2).\n    \"\"\"\n    assert (x.size() == x_gen.size()), \"Real and sampled sizes do not match.\"\n    # One interpolation coefficient per sample, broadcast over the other\n    # dimensions, drawn on the device and with the dtype of the inputs.\n    alpha_size = (len(x), ) + (1, ) * (x.dim() - 1)\n    alpha = torch.rand(alpha_size, device=x.device, dtype=x.dtype)\n    x_hat = x.detach() * alpha + x_gen.detach() * (1 - alpha)\n    x_hat.requires_grad_(True)\n\n    def eps_norm(grad):\n        grad = grad.reshape(len(grad), -1)\n        return (grad * grad + eps).sum(-1).sqrt()\n\n    def bi_penalty(norm):\n        return (norm - 1) ** 2\n\n    # create_graph=True so that the penalty can itself be backpropagated\n    # into the critic parameters.\n    grad_xhat = torch.autograd.grad(\n        model(x_hat).sum(), x_hat, create_graph=True)[0]\n\n    penalty = w * bi_penalty(eps_norm(grad_xhat)).mean()\n\n    return penalty\n\n\ncriterion_bce = nn.BCELoss()\ncriterion_l1 = nn.L1Loss()\ncriterion_mse = nn.MSELoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training\n--------\n\nWe'll train the encoder, the generator, the discriminator and the code\ndiscriminator to optimize the losses using the Adam optimizer.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def infinite_train_generator(data_loader):\n    \"\"\" Cycle endlessly over a data loader and yield the batch inputs.\n    \"\"\"\n    while True:\n        for batch in data_loader:\n            yield batch.inputs\n\n\n# Backward-compatible alias of the previous (misspelled) name.\ninfinite_train_generartor = infinite_train_generator\n\n\ndef set_requires_grad(trainable, frozen):\n    \"\"\" Unfreeze the parameters of the `trainable` models and freeze the\n    parameters of the `frozen` ones.\n    \"\"\"\n    for model in trainable:\n        for param in model.parameters():\n            param.requires_grad = True\n    for model in frozen:\n        for param in model.parameters():\n            param.requires_grad = False\n\n\nlatent_dim = 1000\nuse_cuda = False\nchannels = 1\n# Size of the downsampled volumes: the (150, 190, 135) BraTS volumes are\n# downsampled with scale=3 by the input transformer (TODO: confirm the\n# rounding applied by downsample).\nin_shape = (50, 64, 45)\nbeta = 10\neps = 1e-15\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\ngenerator = BGGenerator(\n    in_shape=in_shape, out_channels=channels, start_filts=64,\n    latent_dim=latent_dim, mode=\"trilinear\", with_code=True).to(device)\ncode_discriminator = BGCodeDiscriminator(\n    out_channels=channels, code_size=latent_dim, n_units=4096).to(device)\ndiscriminator = BGDiscriminator(\n    in_shape=in_shape, in_channels=channels, out_channels=channels,\n    start_filts=64, with_logit=False).to(device)\nencoder = BGDiscriminator(\n    in_shape=in_shape, in_channels=channels, out_channels=latent_dim,\n    start_filts=64, with_logit=False).to(device)\ng_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)\ncd_optimizer = torch.optim.Adam(code_discriminator.parameters(), lr=0.0002)\nd_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)\ne_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.0002)\nboard = Board(port=8097, host=\"http://localhost\", env=\"vae\")\noutdir = \"/tmp/vae-gan/checkpoint\"\nos.makedirs(outdir, exist_ok=True)\n\ng_iter = 1\nd_iter = 1\ncd_iter = 1\ntotal_iter = 200000\ntrain_loader = manager.get_dataloader(train=True, validation=False,\n                                      fold_index=0).train\nloader = infinite_train_generator(train_loader)\n\nfor iteration in range(total_iter):\n\n    # Train Encoder - Generator\n    set_requires_grad(trainable=[encoder, generator],\n                      frozen=[discriminator, code_discriminator])\n    for iters in range(g_iter):\n        generator.zero_grad()\n        encoder.zero_grad()\n        real_images = next(loader)\n        # Local batch size: the global batch_size of the data cell is kept.\n        n_samples = real_images.size(0)\n        real_images = real_images.to(device, non_blocking=True)\n        z_rand = torch.randn((n_samples, latent_dim)).to(device)\n        z_hat = encoder(real_images).view(n_samples, -1)\n        x_hat = generator(z_hat)\n        x_rand = generator(z_rand)\n        c_loss = - code_discriminator(z_hat).mean()\n\n        d_real_loss = discriminator(x_hat).mean()\n        d_fake_loss = discriminator(x_rand).mean()\n        d_loss = - d_fake_loss - d_real_loss\n        l1_loss = 10 * criterion_l1(x_hat, real_images)\n        loss1 = l1_loss + c_loss + d_loss\n\n        # No later step backpropagates through this graph: do not retain it.\n        loss1.backward()\n        e_optimizer.step()\n        # A single generator update per backward pass (it was previously\n        # stepped twice on the same gradients).\n        g_optimizer.step()\n\n    # Train discriminator\n    set_requires_grad(trainable=[discriminator],\n                      frozen=[code_discriminator, encoder, generator])\n    for iters in range(d_iter):\n        d_optimizer.zero_grad()\n        real_images = next(loader)\n        n_samples = real_images.size(0)\n        z_rand = torch.randn((n_samples, latent_dim)).to(device)\n        real_images = real_images.to(device, non_blocking=True)\n        # The encoder and the generator are frozen: no graph is needed.\n        with torch.no_grad():\n            z_hat = encoder(real_images).view(n_samples, -1)\n            x_hat = generator(z_hat)\n            x_rand = generator(z_rand)\n        x_loss2 = (-2 * discriminator(real_images).mean() +\n                   discriminator(x_hat).mean() +\n                   discriminator(x_rand).mean())\n        gradient_penalty_r = calc_gradient_penalty(\n            discriminator, real_images, x_rand)\n        gradient_penalty_h = calc_gradient_penalty(\n            discriminator, real_images, x_hat)\n\n        loss2 = x_loss2 + gradient_penalty_r + gradient_penalty_h\n        loss2.backward()\n        d_optimizer.step()\n\n    # Train code discriminator\n    set_requires_grad(trainable=[code_discriminator],\n                      frozen=[discriminator, encoder, generator])\n    for iters in range(cd_iter):\n        cd_optimizer.zero_grad()\n        z_rand = torch.randn((n_samples, latent_dim)).to(device)\n        gradient_penalty_cd = calc_gradient_penalty(\n            code_discriminator, z_hat, z_rand)\n        # Score the encoded codes with the current, trainable code\n        # discriminator: c_loss from the encoder step was computed while it\n        # was frozen and therefore gave it no gradient.\n        loss3 = (- code_discriminator(z_rand).mean() +\n                 code_discriminator(z_hat).mean() + gradient_penalty_cd)\n\n        loss3.backward()\n        cd_optimizer.step()\n\n    # Visualization\n    if iteration % 4 == 0:\n        print(\"[{0}/{1}]\".format(iteration, total_iter),\n              \"D: {:<8.3}\".format(loss2.item()),\n              \"En Ge: {:<8.3}\".format(loss1.item()),\n              \"Code: {:<8.3}\".format(loss3.item()))\n\n        # `volume` rather than `data`: do not shadow the dataset object.\n        for name, volume in [(\"X_real\", real_images), (\"X_dec\", x_hat),\n                             (\"X_rand\", x_rand)]:\n            # First sample mapped with 0.5 * x + 0.5, middle slice along\n            # the last axis.\n            featmask = (0.5 * volume[0] + 0.5).detach().cpu().numpy()\n            img = featmask[..., featmask.shape[-1] // 2]\n            img = np.expand_dims(img, axis=1)\n            img_max = img.max()\n            # Only rescale when the maximum is positive (no division by zero).\n            if img_max > 0:\n                img = (img / img_max) * 255\n            board.viewer.images(\n                img,\n                opts={\n                    \"title\": name,\n                    \"caption\": name},\n                win=name)\n\n    # Save model (the \"_epoch_\" file names actually count iterations)\n    if (iteration + 1) % 100 == 0:\n        for name, model in [(\"generator\", generator),\n                            (\"code_discriminator\", code_discriminator),\n                            (\"discriminator\", discriminator),\n                            (\"encoder\", encoder)]:\n            fname = os.path.join(\n                outdir, name + \"_epoch_\" + str(iteration + 1) + \".pth\")\n            torch.save(model.state_dict(), fname)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conclusion\n----------\n\nVariational Auto-Encoder (VAE) GANs are free from mode collapse, but their\noutputs are characterized by blurriness. To address both the mode collapse\nof GANs and the blurriness of VAEs, the next tutorial uses \u03b1-GAN, a\nsolution that combines both models.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}