{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nConditional Variational AutoEncoder (CVAE)\n==========================================\n\nCredit: A Grigis\n\nBased on:\n\n- https://ravirajag.dev\n\nThis tutorial builds intuition for a simple conditional Variational\nAutoencoder (VAE) implementation in pynet.\nAfter reading it, you'll understand the technical details needed to\nimplement a conditional VAE.\nThe main difference from the vanilla VAE is that the vanilla model generates\nimages at random, whereas here we can choose which digit we want the model\nto generate.\n\nLet's begin by importing the required modules:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport numpy as np\nfrom scipy import ndimage\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as func\nfrom pynet.datasets import DataManager, fetch_minst\nfrom pynet.interfaces import DeepLearningInterface\nfrom pynet.plotting import Board, update_board"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model will be trained on MNIST, a dataset of handwritten digits. Each\ninput is a 28\u00d728 grayscale image, flattened into a vector of 784 values\nby the `flatten` transform below.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def flatten(arr):\n    return arr.flatten()\n\ndata = fetch_minst(datasetdir=\"/neurospin/nsap/datasets/minst\")\nmanager = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    labels=[\"label\"],\n    stratify_label=\"label\",\n    number_of_folds=10,\n    batch_size=64,\n    test_size=0,\n    input_transforms=[flatten],\n    add_input=True,\n    sample_size=0.5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Model\n---------\n\nThe model is composed of three sub-networks:\n\n1. Given x (an image), encode it into a distribution over the latent space,\n   referred to as Q(z|x).\n2. Given z in the latent space (the code representation of an image), decode\n   it into the image it represents, referred to as f(z).\n3. Given x, classify its digit by mapping it to a layer of size 10 whose\n   i-th value scores the i-th digit (a softmax turns these logits into\n   probabilities).\n\nThe first two sub-networks form the vanilla VAE framework.\n\nThe third one serves as an auxiliary task that forces some dimensions of the\nlatent vector fed to the decoder to encode the digit found in an image.\nIn the vanilla VAE case we don't care what information each dimension of the\nlatent space holds: the model can learn to encode whatever information it\nfinds valuable for its task. Since we're familiar with the dataset, we know\nthe digit type should be important, so we help the model by providing it with\nthis information. Moreover, we'll use this information to generate images\nconditioned on the digit type, as we will see later.\n\nGiven the digit type, we encode it with a one-hot encoding, that is, a\nvector of size 10. These 10 numbers are concatenated to the latent vector,\nso when decoding that vector into an image, the model makes use of the\ndigit information.\n\nThere are two ways to provide the model with a one-hot encoding vector:\n\n1. Add it as an input to the model.\n2. Add it as a label so the model has to predict it by itself: add another\n   sub-network that predicts a vector of size 10, trained with the cross\n   entropy against the expected one-hot vector.\n\nWe'll go with the second option. Why? At test time we can use the model\nin two ways:\n\n1. Provide an image as input, and infer a latent vector.\n2. Provide a latent vector as input, and generate an image.\n\nSince we want to support the first option too, we can't provide the model\nwith the digit as input, because we won't know it at test time. Hence, the\nmodel must learn to predict it. During training, the classifier's prediction\n(made discrete by a straight-through Gumbel-softmax in `CVAE.forward`) is what\ngets concatenated to the latent sample.\n\n"
      ]
    },
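    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the one-hot encoding and concatenation described above, using hypothetical labels and a random latent sample of size 20 (the latent size used in the training cell). It relies on PyTorch's `torch.nn.functional.one_hot` rather than the `idx2onehot` helper defined in the next cell; both return one row per label with a single 1 at the label's index.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn.functional as func\n\n# Digit labels for a hypothetical batch of four images.\nlabels = torch.tensor([2, 0, 9, 2])\n\n# One-hot encode them: row i has a single 1 at column labels[i].\nonehot = func.one_hot(labels, num_classes=10).float()\n\n# Concatenate the 10 digit dimensions to a random latent sample, which is\n# how the decoder input is built in this tutorial.\nz = torch.randn(labels.size(0), 20)\ndecoder_input = torch.cat([z, onehot], dim=1)\nprint(decoder_input.shape)  # torch.Size([4, 30])"
      ]
    },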
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def idx2onehot(idx, n):\n    \"\"\" Given a class label, we will convert it into one-hot encoding.\n    \"\"\"\n    assert idx.ndim == 1\n    assert torch.max(idx).item() < n\n    idx = idx.view(-1, 1)\n    onehot = torch.zeros(idx.size(0), n, device=idx.device)\n    onehot.scatter_(1, idx, 1)\n\n    return onehot\n\nclass Encoder(nn.Module):\n    \"\"\" This is the encoder part of the VAE.\n    \"\"\"\n    def __init__(self, input_dim, hidden_dim, z_dim):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        input_dim: int\n            the size of input (in case of MNIST 28 * 28).\n        hidden_dim: int\n            the size of hidden dimension.\n        z_dim: int\n            the latent dimension.\n        \"\"\"\n        super().__init__()\n        self.linear = nn.Linear(input_dim, hidden_dim)\n        self.mu = nn.Linear(hidden_dim, z_dim)\n        self.var = nn.Linear(hidden_dim, z_dim)\n\n    def forward(self, x):\n        # x is of shape [batch_size, input_dim]\n        hidden = func.relu(self.linear(x))\n        # hidden is of shape [batch_size, hidden_dim]\n        z_mu = self.mu(hidden)\n        # z_mu is of shape [batch_size, latent_dim]\n        z_var = self.var(hidden)\n        # z_var is of shape [batch_size, latent_dim]: it holds log(var),\n        # not the variance itself\n\n        return z_mu, z_var\n\nclass Decoder(nn.Module):\n    \"\"\" This is the decoder part of the VAE.\n    \"\"\"\n    def __init__(self, z_dim, hidden_dim, output_dim, n_classes):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        z_dim: int\n            the latent size.\n        hidden_dim: int\n            the size of hidden dimension.\n        output_dim: int\n            the output dimension (in case of MNIST it is 28 * 28).\n        n_classes: int\n            the number of classes (dimension of one-hot representation of\n            labels).\n        \"\"\"\n        super().__init__()\n        self.latent_to_hidden = nn.Linear(z_dim + n_classes, hidden_dim)\n        self.hidden_to_out = nn.Linear(hidden_dim, output_dim)\n\n    def forward(self, x):\n        # x is of shape [batch_size, latent_dim + n_classes]\n        hidden = func.relu(self.latent_to_hidden(x))\n        # hidden is of shape [batch_size, hidden_dim]\n        predicted = torch.sigmoid(self.hidden_to_out(hidden))\n        # predicted is of shape [batch_size, output_dim], values in [0, 1]\n\n        return predicted\n\nclass DigitClassifier(nn.Module):\n    \"\"\" This is the digit classifier part of the CVAE.\n    \"\"\"\n    def __init__(self, input_dim, hidden_dim, n_classes):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        input_dim: int\n            the size of input (in case of MNIST 28 * 28).\n        hidden_dim: int\n            the size of hidden dimension.\n        n_classes: int\n            the number of classes (dimension of one-hot representation of\n            labels).\n        \"\"\"\n        super().__init__()\n        self.linear1 = nn.Linear(input_dim, hidden_dim)\n        self.linear2 = nn.Linear(hidden_dim, n_classes)\n\n    def forward(self, x):\n        # x is of shape [batch_size, input_dim]\n        hidden = func.relu(self.linear1(x))\n        # hidden is of shape [batch_size, hidden_dim]\n        logits = self.linear2(hidden)\n        # logits is of shape [batch_size, n_classes]\n\n        return logits\n\nclass CVAE(nn.Module):\n    \"\"\" This is the conditional VAE.\n    \"\"\"\n    def __init__(self, input_dim, hidden_dim, latent_dim, n_classes):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        input_dim: int\n            the size of input (in case of MNIST 28 * 28).\n        hidden_dim: int\n            the size of hidden dimension.\n        latent_dim: int\n            the latent dimension.\n        n_classes: int\n            the number of classes (dimension of one-hot representation of\n            labels).\n        \"\"\"\n        super(CVAE, self).__init__()\n        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)\n        self.decoder = Decoder(latent_dim, hidden_dim, input_dim, n_classes)\n        self.classifier = DigitClassifier(input_dim, hidden_dim, n_classes)\n\n    def forward(self, x):\n        # encode an image into a distribution over the latent space\n        z_mu, z_var = self.encoder(x)\n\n        # sample a latent vector with the reparameterization trick:\n        # z = mu + eps * std, with eps ~ N(0, I). The encoder predicts\n        # log(var), so exponentiating guarantees a positive variance; any\n        # activation function whose range is the positive numbers could be\n        # used here.\n        std = torch.exp(z_var / 2)\n        eps = torch.randn_like(std)\n        z_sample = eps.mul(std).add_(z_mu)\n\n        # classify the digit: the straight-through Gumbel-softmax returns a\n        # one-hot vector while letting gradients reach the classifier\n        logits = self.classifier(x)\n        y = func.gumbel_softmax(logits, hard=True)\n\n        # decode the latent vector, concatenated to the digit\n        # classification, into an image\n        predicted = self.decoder(torch.cat([z_sample, y], dim=1))\n\n        return predicted, {\"z_mu\": z_mu, \"z_var\": z_var, \"logits\": logits,\n                           \"y\": y}"
      ]
    },
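    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A small numerical check, on hypothetical values, of the two sampling tricks used in `CVAE.forward`: the reparameterization trick (samples have the requested mean and standard deviation, and gradients reach `z_mu` and the log-variance) and `torch.nn.functional.gumbel_softmax(..., hard=True)`, whose forward value is one-hot (up to float rounding) while gradients flow back to the logits through the underlying soft sample. This is an illustrative sketch, not part of the model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn.functional as func\n\ntorch.manual_seed(0)\n\n# Hypothetical encoder outputs: mean and log-variance (std of 1 and 3).\nz_mu = torch.tensor([1.0, -2.0], requires_grad=True)\nz_logvar = torch.log(torch.tensor([1.0, 9.0])).requires_grad_()\n\n# Reparameterization trick: z = mu + eps * std with eps ~ N(0, I).\nn_samples = 100000\nstd = torch.exp(z_logvar / 2)\neps = torch.randn(n_samples, 2)\nz = z_mu + eps * std\nprint(z.mean(dim=0), z.std(dim=0))  # close to [1, -2] and [1, 3]\n\n# Gradients flow back to the distribution parameters.\nz.mean().backward()\nprint(z_mu.grad)  # about 0.5 per entry: each mu_j feeds half of the values\n\n# Straight-through Gumbel-softmax: the forward value is one-hot (up to\n# float rounding) ...\nlogits = torch.tensor([[0.1, 2.0, -1.0]], requires_grad=True)\ny = func.gumbel_softmax(logits, tau=1.0, hard=True)\nprint(y)\n# ... while the backward pass uses the gradient of the soft sample.\n(y * torch.tensor([[1.0, 2.0, 3.0]])).sum().backward()\nprint(logits.grad)"
      ]
    },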
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Loss\n----\n\nThe CVAE loss combines three terms:\n\n1. Reconstruction loss: how well we can reconstruct the image (binary\n   cross-entropy between the decoded and the input pixels).\n2. KL divergence: how far the inferred distribution over the latent space is\n   from the prior. Since the prior is a standard Gaussian and the inferred\n   distribution is a Gaussian with a diagonal covariance matrix, the\n   KL divergence has a closed form:\n   $D_{KL} = \\frac{1}{2} \\sum_j \\left(\\sigma_j^2 + \\mu_j^2 - 1 - \\log \\sigma_j^2\\right)$.\n3. Classification loss: how well we predict the digit (cross-entropy). A\n   classification weight balances this term against the VAE loss (terms 1\n   and 2), since the two objectives pull in different directions.\n\n"
      ]
    },
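    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check of the closed form used in `VAELoss`, the sketch below compares $\\frac{1}{2}\\sum(\\sigma^2 + \\mu^2 - 1 - \\log\\sigma^2)$ with the KL divergence computed by `torch.distributions.kl_divergence` between $\\mathcal{N}(\\mu, \\sigma^2)$ and $\\mathcal{N}(0, 1)$, summed over latent dimensions and samples. The encoder outputs are random, hypothetical values.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch.distributions import Normal, kl_divergence\n\ntorch.manual_seed(0)\n\n# Hypothetical encoder outputs: a batch of 4 samples, 20 latent dimensions.\nz_mu = torch.randn(4, 20)\nz_logvar = torch.randn(4, 20)\n\n# Closed form used in VAELoss: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2).\nkl_closed = 0.5 * torch.sum(torch.exp(z_logvar) + z_mu ** 2 - 1.0 - z_logvar)\n\n# Reference: KL(N(mu, sigma^2) || N(0, 1)) from torch.distributions,\n# summed over latent dimensions and samples.\nposterior = Normal(z_mu, torch.exp(z_logvar / 2))\nprior = Normal(torch.zeros_like(z_mu), torch.ones_like(z_mu))\nkl_reference = kl_divergence(posterior, prior).sum()\n\nprint(kl_closed.item(), kl_reference.item())"
      ]
    },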
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class DecodeLoss(object):\n    \"\"\" CVAE loss: VAE loss plus the weighted classification loss.\n    \"\"\"\n    def __init__(self, classification_weight):\n        super(DecodeLoss, self).__init__()\n        self.layer_outputs = None\n        self.classification_weight = classification_weight\n        self.auto_encode = VAELoss()\n        self.classification = ClassificationLoss()\n\n    def __call__(self, x_sample, x, y):\n        if self.layer_outputs is None:\n            raise ValueError(\"The model needs to return the latent space \"\n                             \"distribution parameters z_mu, z_var and the \"\n                             \"classifier logits.\")\n        self.auto_encode.layer_outputs = self.layer_outputs\n        self.classification.layer_outputs = self.layer_outputs\n        loss = (self.auto_encode(x_sample, x, y) + self.classification_weight *\n                self.classification(x_sample, x, y))\n\n        return loss\n\nclass VAELoss(object):\n    \"\"\" Reconstruction loss plus KL divergence, averaged over the batch.\n    \"\"\"\n    def __init__(self):\n        super(VAELoss, self).__init__()\n        self.layer_outputs = None\n\n    def __call__(self, x_sample, x, y):\n        if self.layer_outputs is None:\n            raise ValueError(\"The model needs to return the latent space \"\n                             \"distribution parameters z_mu, z_var.\")\n        z_mu = self.layer_outputs[\"z_mu\"]\n        # z_var holds the log-variance predicted by the encoder\n        z_var = self.layer_outputs[\"z_var\"]\n        # reconstruction loss, summed over pixels and samples\n        recon_loss = func.binary_cross_entropy(x_sample, x, reduction=\"sum\")\n        # closed-form KL divergence to the N(0, I) prior\n        kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)\n        # VAE loss per sample\n        loss_auto_encode = (recon_loss + kl_loss) / x_sample.size(0)\n\n        return loss_auto_encode\n\nclass ClassificationLoss(object):\n    \"\"\" Cross-entropy between the classifier logits and the digit labels.\n    \"\"\"\n    def __init__(self):\n        super(ClassificationLoss, self).__init__()\n        self.layer_outputs = None\n\n    def __call__(self, x_sample, x, y):\n        if self.layer_outputs is None:\n            raise ValueError(\"The model needs to return the classifier \"\n                             \"logits.\")\n        logits = self.layer_outputs[\"logits\"]\n        # classification: y must hold the integer digit labels\n        criterion = nn.CrossEntropyLoss(reduction=\"mean\")\n        loss_classification = criterion(logits, y)\n\n        return loss_classification"
      ]
    },
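    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The classification term is the cross-entropy between the predicted distribution and the one-hot target vector. With a one-hot target this reduces to minus the log-probability of the true digit, which is why `nn.CrossEntropyLoss` can take raw logits and integer labels directly. A quick check on hypothetical values:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as func\n\ntorch.manual_seed(0)\nlogits = torch.randn(5, 10)             # hypothetical classifier outputs\nlabels = torch.tensor([3, 1, 4, 1, 5])  # integer digit labels\n\n# Built-in loss: log-softmax + negative log-likelihood, batch mean.\nce = nn.CrossEntropyLoss(reduction='mean')(logits, labels)\n\n# Explicit cross-entropy with one-hot targets: -sum_k t_k log p_k, batch mean.\nonehot = func.one_hot(labels, num_classes=10).float()\nce_manual = -(onehot * func.log_softmax(logits, dim=1)).sum(dim=1).mean()\n\nprint(ce.item(), ce_manual.item())"
      ]
    },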
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training\n--------\n\nWe'll train the model to optimize both losses (the VAE loss and the\nclassification loss) using the Adam optimizer.\n\nAt the end of every epoch we'll sample latent vectors and decode them into\nimages, so we can visualize how the generative power of the model improves\nover the epochs. The sampling method is as follows:\n\n1. Deterministically set the dimensions used for digit classification\n   according to the digit we want to generate an image for. For example, to\n   generate an image of the digit 2, these dimensions are set to\n   [0, 0, 1, 0, 0, 0, 0, 0, 0, 0].\n2. Randomly sample the other dimensions according to the prior, a standard\n   multivariate Gaussian. We'll use the same sampled values for all the\n   digits generated in a given epoch. This way we get a feel for what is\n   encoded in the other dimensions, for example stroke style.\n\nThe intuition behind step 1 is that after convergence the model should be\nable to classify the digit in an input image using these dimensions. On the\nother hand, these dimensions are also used in the decoding step to generate\nthe image. This means the decoder sub-network learns that when these\ndimensions hold the values corresponding to the digit 2, it should generate\nan image of that digit. Therefore, if we manually set these dimensions to\nencode the digit 2, we'll get a generated image of that digit.\n\n"
      ]
    },
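    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sampling steps above can be sketched in isolation. The decoder here is a hypothetical, untrained stand-in with the same input size (20 latent plus 10 digit dimensions) and output size (28 x 28) as the tutorial's `Decoder`; only the layout of its input batch matters.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn as nn\n\ntorch.manual_seed(0)\nlatent_dim, n_classes = 20, 10\n\n# Hypothetical, untrained stand-in with the same sizes as Decoder.\ndecoder = nn.Sequential(\n    nn.Linear(latent_dim + n_classes, 128), nn.ReLU(),\n    nn.Linear(128, 28 * 28), nn.Sigmoid())\n\n# Step 2: one style vector drawn from the N(0, I) prior, shared by all digits.\nz = torch.randn(1, latent_dim).repeat(n_classes, 1)\n# Step 1: row d of the identity matrix is the one-hot code of digit d.\ny = torch.eye(n_classes)\nwith torch.no_grad():\n    images = decoder(torch.cat([z, y], dim=1)).view(-1, 28, 28)\n\nprint(images.shape)  # torch.Size([10, 28, 28])"
      ]
    },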
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def prepare_pred(y_pred):\n    \"\"\" Reshape the first three predictions into enlarged 28x28 images.\n    \"\"\"\n    y_pred = y_pred[:3]\n    y_pred = y_pred.reshape(-1, 28, 28)\n    y_pred = np.asarray([ndimage.zoom(arr, 5, order=0) for arr in y_pred])\n    y_pred = np.expand_dims(y_pred, axis=1)\n    y_pred = (y_pred / y_pred.max()) * 255\n    return y_pred\n\ndef sampling(signal):\n    \"\"\" Sample from the prior and generate one image per digit.\n    \"\"\"\n    device = signal.object.device\n    model = signal.object.model\n    board = signal.object.board\n    # one latent vector shared by the 10 digits, concatenated to the\n    # one-hot code of each digit (the rows of the identity matrix)\n    z = torch.randn(1, 20).to(device).repeat(10, 1)\n    y = torch.eye(10).to(device, dtype=z.dtype)\n    z = torch.cat((z, y), dim=1)\n    # run only the decoder; no gradients are needed\n    with torch.no_grad():\n        reconstructed_img = model.decoder(z)\n    # move to CPU before converting to numpy (required on GPU)\n    img = reconstructed_img.view(-1, 28, 28).cpu().numpy()\n    # display result\n    img = np.asarray([ndimage.zoom(arr, 5, order=0) for arr in img])\n    img = np.expand_dims(img, axis=1)\n    img = (img / img.max()) * 255\n    board.viewer.images(\n        img,\n        opts={\n            \"title\": \"sampling\",\n            \"caption\": \"sampling\"},\n        win=\"sampling\")\n\nmodel = CVAE(input_dim=(28 * 28), hidden_dim=128, latent_dim=20, n_classes=10)\ninterface = DeepLearningInterface(\n    model=model,\n    optimizer_name=\"Adam\",\n    learning_rate=0.001,\n    metrics=[ClassificationLoss(), VAELoss()],\n    loss=DecodeLoss(classification_weight=10.))\ninterface.board = Board(\n    port=8097, host=\"http://localhost\", env=\"vae\", display_pred=True,\n    prepare_pred=prepare_pred)\ninterface.add_observer(\"after_epoch\", update_board)\ninterface.add_observer(\"after_epoch\", sampling)\ntest_history, train_history = interface.training(\n    manager=manager,\n    nb_epochs=20,\n    checkpointdir=None,\n    fold_index=0,\n    with_validation=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conclusion\n----------\n\nUsing a simple feed-forward network (no fancy convolutions), we're able to\ngenerate nice-looking images after 10 epochs. The model learned to use the\nspecial digit dimensions quite quickly, and we already see the sequence of\ndigits we were trying to generate.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}