{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nVanilla Variational AutoEncoder (VAE)\n=====================================\n\nCredit: A Grigis\n\nBased on:\n\n- https://ravirajag.dev\n\nThis tutorial gives the intuition behind a simple Variational AutoEncoder\n(VAE) implementation in pynet.\nAfter reading it, you'll understand the technical details needed to\nimplement a VAE.\n\nLet's begin by importing what we need:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport numpy as np\nfrom scipy import ndimage\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as func\nfrom pynet.datasets import DataManager, fetch_minst\nfrom pynet.interfaces import DeepLearningInterface\nfrom pynet.plotting import Board, update_board"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model will be trained on MNIST, a dataset of handwritten digits. Each\ninput is an image in $\\mathbb{R}^{28 \\times 28}$.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def flatten(arr):\n",
        "    \"\"\" Return a one-dimensional copy of the input.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    arr: object with a 'flatten' method\n",
        "        the data to collapse (presumably a numpy array holding one\n",
        "        image - TODO confirm against the DataManager).\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    flat: same kind as arr\n",
        "        the result of 'arr.flatten()'.\n",
        "    \"\"\"\n",
        "    flat = arr.flatten()\n",
        "    return flat\n",
        "\n",
        "\n",
        "# Fetch the dataset description, then build the manager in charge of the\n",
        "# folds and batches.\n",
        "# NOTE(review): the meaning of the pynet options below is not visible from\n",
        "# this notebook - 'add_input' presumably also serves the input as the\n",
        "# target (auto-encoder setting) and 'sample_size' presumably keeps 5% of\n",
        "# the data; confirm in the pynet DataManager documentation.\n",
        "data = fetch_minst(datasetdir=\"/neurospin/nsap/datasets/minst\")\n",
        "manager = DataManager(\n",
        "    # where the data lives\n",
        "    input_path=data.input_path,\n",
        "    metadata_path=data.metadata_path,\n",
        "    # each sample goes through 'flatten' before reaching the model\n",
        "    input_transforms=[flatten],\n",
        "    add_input=True,\n",
        "    # splitting strategy\n",
        "    stratify_label=\"label\",\n",
        "    number_of_folds=10,\n",
        "    test_size=0,\n",
        "    sample_size=0.05,\n",
        "    # batching\n",
        "    batch_size=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Model\n---------\n\nThe model is composed of two sub-networks:\n\n1. Given x (image), encode it into a distribution over the latent space -\n   referred to as Q(z|x).\n2. Given z in latent space (code representation of an image), decode it into\n   the image it represents - referred to as f(z).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Encoder(nn.Module):\n",
        "    \"\"\" This is the encoder part of the VAE: it maps an input to the mean\n",
        "    and log-variance of a Gaussian over the latent space.\n",
        "    \"\"\"\n",
        "    def __init__(self, input_dim, hidden_dim, latent_dim, dropout):\n",
        "        \"\"\" Init class.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        input_dim: int\n",
        "            the size of input (in case of MNIST 28 * 28).\n",
        "        hidden_dim: int\n",
        "            the size of hidden dimension.\n",
        "        latent_dim: int\n",
        "            the latent dimension.\n",
        "        dropout: float\n",
        "            the dropout rate (trick for missing data).\n",
        "        \"\"\"\n",
        "        super(Encoder, self).__init__()\n",
        "        self.linear = nn.Linear(input_dim, hidden_dim)\n",
        "        self.mu = nn.Linear(hidden_dim, latent_dim)\n",
        "        self.logvar = nn.Linear(hidden_dim, latent_dim)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\" Encode a batch.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        x: torch.Tensor\n",
        "            the input of shape [batch_size, input_dim].\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        z_mu: torch.Tensor\n",
        "            the latent means of shape [batch_size, latent_dim].\n",
        "        z_var: torch.Tensor\n",
        "            the latent log-variances of shape [batch_size, latent_dim].\n",
        "        \"\"\"\n",
        "        # dropout is applied on the input itself, before the first layer\n",
        "        hidden = func.relu(self.linear(self.dropout(x)))\n",
        "        # hidden is of shape [batch_size, hidden_dim]\n",
        "        z_mu = self.mu(hidden)\n",
        "        # z_mu is of shape [batch_size, latent_dim]\n",
        "        z_var = self.logvar(hidden)\n",
        "        # z_var is of shape [batch_size, latent_dim]: this is log(var)\n",
        "\n",
        "        return z_mu, z_var\n",
        "\n",
        "\n",
        "class Decoder(nn.Module):\n",
        "    \"\"\" This is the decoder part of the VAE: it maps a latent vector back\n",
        "    to the input space.\n",
        "    \"\"\"\n",
        "    def __init__(self, latent_dim, hidden_dim, output_dim):\n",
        "        \"\"\" Init class.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        latent_dim: int\n",
        "            the latent size.\n",
        "        hidden_dim: int\n",
        "            the size of hidden dimension.\n",
        "        output_dim: int\n",
        "            the output dimension (in case of MNIST it is 28 * 28).\n",
        "        \"\"\"\n",
        "        super(Decoder, self).__init__()\n",
        "        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)\n",
        "        self.hidden_to_out = nn.Linear(hidden_dim, output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\" Decode a batch of latent vectors.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        x: torch.Tensor\n",
        "            the latent vectors of shape [batch_size, latent_dim].\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        predicted: torch.Tensor\n",
        "            the reconstruction of shape [batch_size, output_dim], with\n",
        "            values in [0, 1] (sigmoid output).\n",
        "        \"\"\"\n",
        "        hidden = func.relu(self.latent_to_hidden(x))\n",
        "        # hidden is of shape [batch_size, hidden_dim]\n",
        "        predicted = torch.sigmoid(self.hidden_to_out(hidden))\n",
        "        # predicted is of shape [batch_size, output_dim]\n",
        "\n",
        "        return predicted\n",
        "\n",
        "\n",
        "class VAE(nn.Module):\n",
        "    \"\"\" This is the VAE, which chains an encoder and a decoder.\n",
        "    \"\"\"\n",
        "    def __init__(self, input_dim, hidden_dim, latent_dim, dropout=0):\n",
        "        \"\"\" Init class.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        input_dim: int\n",
        "            the size of input (in case of MNIST 28 * 28).\n",
        "        hidden_dim: int\n",
        "            the size of hidden dimension.\n",
        "        latent_dim: int\n",
        "            the latent dimension.\n",
        "        dropout: float, default 0\n",
        "            the dropout rate (trick for missing data).\n",
        "        \"\"\"\n",
        "        super(VAE, self).__init__()\n",
        "        self.latent_dim = latent_dim\n",
        "        self.dropout = dropout\n",
        "        # the attribute names (including their spelling) are kept as is:\n",
        "        # other cells access the decoder through 'model.decorder'\n",
        "        self.encorder = Encoder(input_dim, hidden_dim, latent_dim, dropout)\n",
        "        self.decorder = Decoder(latent_dim, hidden_dim, input_dim)\n",
        "\n",
        "    def reparameterization(self, mu, logvar):\n",
        "        \"\"\" Sample a latent vector with the reparameterization trick.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        mu: torch.Tensor\n",
        "            the latent means.\n",
        "        logvar: torch.Tensor\n",
        "            the latent log-variances, same shape as mu.\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        z: torch.Tensor\n",
        "            mu + eps * std, with eps drawn from a standard normal.\n",
        "        \"\"\"\n",
        "        std = torch.exp(0.5 * logvar)\n",
        "        # the noise must be Gaussian (randn), not uniform (rand), so that\n",
        "        # z follows N(mu, std ** 2)\n",
        "        eps = torch.randn_like(std)\n",
        "        return mu + eps * std\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\" Encode, sample and decode a batch.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        x: torch.Tensor\n",
        "            the input of shape [batch_size, input_dim].\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        predicted: torch.Tensor\n",
        "            the reconstruction of shape [batch_size, input_dim].\n",
        "        params: dict\n",
        "            the latent distribution parameters, stored under the 'z_mu'\n",
        "            and 'z_var' (log-variance) keys.\n",
        "        \"\"\"\n",
        "        # encode an image into a distribution over the latent space\n",
        "        z_mu, z_var = self.encorder(x)\n",
        "        # sample from the distribution having latent parameters z_mu, z_var\n",
        "        x_sample = self.reparameterization(z_mu, z_var)\n",
        "        # decode the latent vector\n",
        "        predicted = self.decorder(x_sample)\n",
        "\n",
        "        return predicted, {\"z_mu\": z_mu, \"z_var\": z_var}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Loss\n----\n\nThe VAE loss is the sum of two terms:\n\n1. Reconstruction loss: how well we can reconstruct the image.\n2. KL divergence loss: how far the distribution over the latent space is\n   from the prior. Given the prior is a standard Gaussian and the inferred\n   distribution is a Gaussian with a diagonal covariance matrix,\n   the KL divergence has a closed-form expression:\n   $\\frac{1}{2} \\sum_i (\\sigma_i^2 + \\mu_i^2 - 1 - \\log \\sigma_i^2)$.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class DecodeLoss(object):\n",
        "    \"\"\" VAE objective: a summed reconstruction term plus the KL divergence\n",
        "    between the inferred latent Gaussian and a standard Gaussian.\n",
        "\n",
        "    The 'layer_outputs' attribute must hold a dictionary with the 'z_mu'\n",
        "    and 'z_var' (log-variance) tensors before the loss is called\n",
        "    (presumably filled by the training interface - TODO confirm).\n",
        "    \"\"\"\n",
        "    def __init__(self, mse=False):\n",
        "        \"\"\" Init class.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        mse: bool, default False\n",
        "            if set use a summed squared error as reconstruction term,\n",
        "            otherwise use a summed binary cross-entropy.\n",
        "        \"\"\"\n",
        "        super(DecodeLoss, self).__init__()\n",
        "        self.mse = mse\n",
        "        self.layer_outputs = None\n",
        "\n",
        "    def __call__(self, x_sample, x):\n",
        "        \"\"\" Compute the loss.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        x_sample: torch.Tensor\n",
        "            the reconstruction produced by the model.\n",
        "        x: torch.Tensor\n",
        "            the target, same shape as x_sample.\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        loss: torch.Tensor\n",
        "            the scalar reconstruction + KL loss.\n",
        "\n",
        "        Raises\n",
        "        ------\n",
        "        ValueError\n",
        "            if 'layer_outputs' has not been set.\n",
        "        \"\"\"\n",
        "        if self.layer_outputs is None:\n",
        "            raise ValueError(\n",
        "                \"The model needs to return the latent space distribution \"\n",
        "                \"parameters z_mu, z_var.\")\n",
        "        mean = self.layer_outputs[\"z_mu\"]\n",
        "        log_variance = self.layer_outputs[\"z_var\"]\n",
        "        # reconstruction term: pick the criterion, then sum over all elements\n",
        "        if self.mse:\n",
        "            criterion = func.mse_loss\n",
        "        else:\n",
        "            criterion = func.binary_cross_entropy\n",
        "        reconstruction = criterion(x_sample, x, reduction=\"sum\")\n",
        "        # closed-form KL divergence to the standard Gaussian prior\n",
        "        divergence = 0.5 * torch.sum(\n",
        "            torch.exp(log_variance) + mean ** 2 - 1.0 - log_variance)\n",
        "\n",
        "        return reconstruction + divergence"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training\n--------\n\nWe'll train the model to minimize the sum of both losses using the Adam\noptimizer.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def prepare_pred(y_pred):\n",
        "    \"\"\" Turn the first predictions into images ready to be displayed.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    y_pred: array\n",
        "        the flattened predictions, one 28 * 28 image per row (presumably\n",
        "        a numpy array provided by the board - TODO confirm).\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    y_pred: numpy.ndarray\n",
        "        at most three images of shape [n, 1, 140, 140] rescaled so that\n",
        "        the maximum value is 255.\n",
        "    \"\"\"\n",
        "    # keep the first three predictions only\n",
        "    y_pred = y_pred[:3]\n",
        "    y_pred = y_pred.reshape(-1, 28, 28)\n",
        "    # enlarge each image by a factor 5 without interpolation (order=0)\n",
        "    y_pred = np.asarray([ndimage.zoom(arr, 5, order=0) for arr in y_pred])\n",
        "    # add a channel axis\n",
        "    y_pred = np.expand_dims(y_pred, axis=1)\n",
        "    y_pred = (y_pred / y_pred.max()) * 255\n",
        "    return y_pred\n",
        "\n",
        "\n",
        "def sampling(signal):\n",
        "    \"\"\" Sample from the prior distribution and generate an image.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    signal: object\n",
        "        the observer signal: the device, model and board are read from\n",
        "        'signal.object'.\n",
        "    \"\"\"\n",
        "    device = signal.object.device\n",
        "    model = signal.object.model\n",
        "    board = signal.object.board\n",
        "    # sample a latent vector and run only the decoder: this forward pass is\n",
        "    # for display only, no gradient is needed\n",
        "    with torch.no_grad():\n",
        "        z = torch.randn(1, model.latent_dim).to(device)\n",
        "        reconstructed_img = model.decorder(z)\n",
        "    # bring the result back to the host first: numpy() fails on a tensor\n",
        "    # that lives on another device\n",
        "    img = reconstructed_img.view(28, 28).detach().cpu().numpy()\n",
        "    # display result: zoom, add batch and channel axes, rescale to 255\n",
        "    img = ndimage.zoom(img, 5, order=0)\n",
        "    img = np.expand_dims(img, axis=0)\n",
        "    img = np.expand_dims(img, axis=0)\n",
        "    img = (img / img.max()) * 255\n",
        "    board.viewer.images(\n",
        "        img,\n",
        "        opts={\n",
        "            \"title\": \"sampling\",\n",
        "            \"caption\": \"sampling\"},\n",
        "        win=\"sampling\")\n",
        "\n",
        "\n",
        "model = VAE(input_dim=(28 * 28), hidden_dim=128, latent_dim=20, dropout=0.5)\n",
        "interface = DeepLearningInterface(\n",
        "    model=model,\n",
        "    optimizer_name=\"Adam\",\n",
        "    learning_rate=0.001,\n",
        "    loss=DecodeLoss())\n",
        "interface.board = Board(\n",
        "    port=8097, host=\"http://localhost\", env=\"vae\", display_pred=True,\n",
        "    prepare_pred=prepare_pred)\n",
        "# after each epoch: refresh the board, then display a freshly sampled image\n",
        "interface.add_observer(\"after_epoch\", update_board)\n",
        "interface.add_observer(\"after_epoch\", sampling)\n",
        "test_history, train_history = interface.training(\n",
        "    manager=manager,\n",
        "    nb_epochs=50,\n",
        "    checkpointdir=None,\n",
        "    fold_index=0,\n",
        "    with_validation=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conclusion\n----------\n\nUsing a simple feed-forward network (no fancy convolutions) we're able to\ngenerate nice-looking images after 10 epochs. The images are generated\nrandomly: we have no control over which digit is produced. A later tutorial\nwill show how to condition the generation on the digit we want.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}