{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nUnsupervised MoE Variational AutoEncoder (VAE)\n==============================================\n\nCredit: A Grigis\n\nBased on:\n\n- https://towardsdatascience.com/mixture-of-variational-autoencoders-a-fusion-between-moe-and-vae-22c0901a6675\n\nThe Variational Autoencoder (VAE) is a neural network that tries to learn the\nshape of the input space. Once trained, the model can be used to generate\nnew samples from the input space.\n\nIf we have labels for our input data, it\u2019s also possible to condition the\ngeneration process on the label. The idea here is to achieve the same results\nusing an unsupervised approach.\n\nMixture of Experts\n------------------\n\nMoE is a supervised learning framework. It relies on the possibility that the\ninput space can be segmented according to the x->y mapping. How can we train a\nmodel that learns the split points while, at the same time, learning the\nmapping that defines them?\n\nMoE does so using an architecture of multiple subnetworks: one manager and\nmultiple experts. The manager maps the input into a soft decision over the\nexperts, which is used in two contexts:\n\n1. The output of the network is a weighted average of the experts' outputs,\n   where the weights are the manager's output.\n2. The loss function is $\\sum_i p_i (y - \\bar{y}_i)^2$, where $y$ is the label,\n   $\\bar{y}_i$ is the output of the i'th expert, and $p_i$ is the i'th entry of\n   the manager's output. Differentiating the loss gives two results: a) the\n   manager decides for each expert how much it contributes to the loss; in\n   other words, the manager chooses which experts should tune their weights\n   according to their error, and b) the manager tunes the probabilities it\n   outputs so that the experts that got it right get higher probabilities\n   than those that didn\u2019t. This loss function encourages the experts to\n   specialize in different kinds of inputs.\n\nSurely we can replace y with x for the unsupervised case, right? MoE's power\nstems from the fact that each expert specializes in a different segment of the\ninput space with a unique mapping x->y. If we use the mapping x->x, each expert\nwill specialize in a different segment of the input space with unique patterns\nin the input itself.\n\nWe'll use VAEs as the experts. Part of the VAE\u2019s loss is the reconstruction\nloss, where the VAE tries to reconstruct the original input image x.\n\nA cool byproduct of this architecture is that the manager can classify the\ndigit found in an image using its output vector!\n\nOne thing we need to be careful about when training this model is that the\nmanager could easily degenerate into outputting a constant vector,\nregardless of the input at hand. This results in one VAE specialized in all\ndigits, and nine VAEs specialized in nothing. One way to mitigate this, which\nis described in the MoE paper, is to add a balancing term to the loss.\nIt encourages the outputs of the manager, summed over a batch of inputs, to\nbe balanced across experts:\n$\\sum_{\\text{examples in batch}} \\vec{p} \\approx \\text{Uniform}$.\n\nLet's begin with the imports:\n\n"
      ]
    },
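    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two effects of the MoE loss described above can be made concrete with a minimal NumPy sketch (an illustration, not part of the original tutorial): one input, three hypothetical experts and a softmax manager with logits $z_i$. With $e_i = (y - \\bar{y}_i)^2$, the loss is $L = \\sum_i p_i e_i$. The gradient with respect to an expert's output, $-2 p_i (y - \\bar{y}_i)$, is scaled by $p_i$ (result a), and the gradient with respect to the manager logits is $\\partial L / \\partial z_i = p_i (e_i - L)$, so an expert whose error is below the current loss gets a negative gradient and a gradient-descent step raises its probability (result b). All values are made up.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef softmax(logits):\n    # numerically stable softmax over the experts\n    exp = np.exp(logits - logits.max())\n    return exp / exp.sum()\n\n\ndef moe_loss(y, expert_outputs, logits):\n    \"\"\" MoE loss sum_i p_i (y - y_i)^2 and its gradient w.r.t. the logits.\n    \"\"\"\n    probs = softmax(logits)\n    errors = (y - expert_outputs) ** 2     # e_i: squared error of each expert\n    loss = np.sum(probs * errors)          # L = sum_i p_i e_i\n    grad_logits = probs * (errors - loss)  # dL/dz_i = p_i (e_i - L)\n    return loss, probs, grad_logits\n\n\n# hypothetical toy values: expert 0 is right, expert 2 is far off\ny = 1.0\nexpert_outputs = np.array([1.0, 0.5, -1.0])\nlogits = np.zeros(3)  # the manager starts undecided\n\nloss, probs, grad = moe_loss(y, expert_outputs, logits)\nnew_probs = softmax(logits - 1.0 * grad)  # one gradient-descent step\nprint(loss, probs, new_probs)"
      ]
    },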
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport numpy as np\nfrom scipy import ndimage\nimport matplotlib.pyplot as plt\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as func\nfrom torch.distributions import Normal, kl_divergence\nfrom pynet.datasets import DataManager, fetch_minst\nfrom pynet.interfaces import DeepLearningInterface\nfrom pynet.plotting import Board, update_board"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model will be trained on MNIST, a dataset of handwritten digits. Each\ninput is a 28\u00d728 grayscale image, which the `flatten` transform below turns\ninto a vector of 784 values.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def flatten(arr):\n    return arr.flatten()\n\ndata = fetch_minst(datasetdir=\"/neurospin/nsap/datasets/minst\")\nmanager = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    stratify_label=\"label\",\n    number_of_folds=10,\n    batch_size=100,\n    test_size=0,\n    input_transforms=[flatten],\n    add_input=True,\n    sample_size=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
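        "The Model\n---------\n\nEach expert is a VAE composed of two sub-networks:\n\n1. Given x (an image), the encoder maps it to a distribution over the latent\n   space, referred to as Q(z|x).\n2. Given z in the latent space (the code representation of an image), the\n   decoder maps it back to the image it represents, referred to as f(z).\n\nThe manager is a small fully connected network that maps x to a softmax over\nthe experts. Its loss is the expected expert loss (each expert's VAE loss\nweighted by the manager's probability for that expert) plus the balancing\nterm.\n\n"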
      ]
    },
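    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before the model code, a short illustration of the two VAE ingredients it relies on. The encoder outputs $\\mu$ and $\\log\\sigma^2$; the VAE samples $z = \\mu + \\sigma \\epsilon$ with $\\epsilon \\sim N(0, I)$ (the reparameterization trick), so the randomness sits in $\\epsilon$ and gradients can flow through $\\mu$ and $\\sigma$. The regularization term is the KL divergence to the $N(0, I)$ prior, which for a diagonal Gaussian has the closed form $\\frac{1}{2}\\sum_j (\\sigma_j^2 + \\mu_j^2 - 1 - \\log\\sigma_j^2)$. The NumPy sketch below checks both facts on made-up encoder outputs; it is only an illustration, and the model itself uses `torch.distributions.kl_divergence` instead.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nrng = np.random.default_rng(0)\n\n# hypothetical encoder outputs for a 2-dimensional latent space\nz_mu = np.array([0.5, -1.0])\nz_logvar = np.array([0.0, np.log(0.25)])  # variances 1.0 and 0.25\nstd = np.exp(0.5 * z_logvar)\n\n# reparameterization trick: z is a deterministic, differentiable function\n# of (z_mu, std), the randomness being isolated in eps ~ N(0, I)\neps = rng.standard_normal((100000, 2))\nz = z_mu + std * eps\n\n\ndef kl_to_standard_normal(mu, logvar):\n    # closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims\n    return 0.5 * np.sum(np.exp(logvar) + mu ** 2 - 1.0 - logvar)\n\n\nprint(z.mean(axis=0), z.std(axis=0))  # close to z_mu and std\nprint(kl_to_standard_normal(z_mu, z_logvar))"
      ]
    },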
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Encoder(nn.Module):\n    \"\"\" The encoder part of VAE.\n    \"\"\"\n    def __init__(self, input_dim, hidden_dim, latent_dim):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        input_dim: int\n            the size of input (in case of MNIST 28 * 28).\n        hidden_dim: int\n            the size of hidden dimension.\n        latent_dim: int\n            the latent dimension.\n        \"\"\"\n        super().__init__()\n        self.linear = nn.Linear(input_dim, hidden_dim)\n        self.mu = nn.Linear(hidden_dim, latent_dim)\n        self.logvar = nn.Linear(hidden_dim, latent_dim)\n\n    def forward(self, x):\n        hidden = torch.sigmoid(self.linear(x))\n        z_mu = self.mu(hidden)\n        z_logvar = self.logvar(hidden)\n        return z_mu, z_logvar\n\n\nclass Decoder(nn.Module):\n    \"\"\" The decoder part of VAE.\n    \"\"\"\n    def __init__(self, latent_dim, hidden_dim, output_dim):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        latent_dim: int\n            the latent size.\n        hidden_dim: int\n            the size of hidden dimension.\n        output_dim: int\n            the output dimension (in case of MNIST it is 28 * 28).\n        \"\"\"\n        super().__init__()\n        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)\n        self.hidden_to_out = nn.Linear(hidden_dim, output_dim)\n\n    def forward(self, z):\n        hidden = torch.sigmoid(self.latent_to_hidden(z))\n        predicted = torch.sigmoid(self.hidden_to_out(hidden))\n        return predicted\n\n\nclass VAE(nn.Module):\n    \"\"\" This is the VAE.\n    \"\"\"\n    def __init__(self, input_dim, hidden_dim, latent_dim):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        input_dim: int\n            the size of input (in case of MNIST 28 * 28).\n        hidden_dim: int\n            the size of hidden dimension.\n        latent_dim: int\n            the latent dimension.\n        \"\"\"\n        super(VAE, self).__init__()\n        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)\n        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)\n\n    def forward(self, x):\n        # encode an image into a distribution over the latent space\n        z_mu, z_logvar = self.encoder(x)\n\n        # sample a latent vector from the distribution with parameters\n        # z_mu, z_var using the reparameterization trick:\n        # z = z_mu + std * eps, with eps ~ N(0, I)\n        z_var = torch.exp(z_logvar) + 1e-5\n        std = torch.sqrt(z_var)\n        eps = torch.randn_like(std)\n        z_sample = eps.mul(std).add_(z_mu)\n\n        # decode the latent vector\n        predicted = self.decoder(z_sample)\n\n        return predicted, {\"z_mu\": z_mu, \"z_var\": z_var}\n\n\nclass VAELoss(object):\n    \"\"\" The VAE loss: negative log likelihood + KL divergence.\n    \"\"\"\n    def __init__(self, use_distributions=True):\n        super(VAELoss, self).__init__()\n        self.layer_outputs = None\n        self.use_distributions = use_distributions\n\n    def __call__(self, x_sample, x):\n        if self.layer_outputs is None:\n            raise ValueError(\"The model needs to return the latent space \"\n                             \"distribution parameters z_mu, z_var.\")\n        if self.use_distributions:\n            # the model directly returns p(x|z) and stores q(z|x) in \"q\"\n            p = x_sample\n            q = self.layer_outputs[\"q\"]\n        else:\n            z_mu = self.layer_outputs[\"z_mu\"]\n            z_var = self.layer_outputs[\"z_var\"]\n            p = Normal(x_sample, 0.5)\n            q = Normal(z_mu, z_var.pow(0.5))\n\n        # reconstruction loss: negative log likelihood\n        ll_loss = - p.log_prob(x).sum(-1, keepdim=True)\n        # regularization loss: KL divergence to the N(0, 1) prior\n        kl_loss = kl_divergence(q, Normal(0, 1)).sum(-1, keepdim=True)\n\n        combined_loss = ll_loss + kl_loss\n\n        return combined_loss, {\"ll_loss\": ll_loss, \"kl_loss\": kl_loss}\n\n\nclass Manager(nn.Module):\n    \"\"\" The MoE manager: a soft decision over the experts.\n    \"\"\"\n    def __init__(self, input_dim, hidden_dim, experts, latent_dim,\n                 log_alpha=None):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        input_dim: int\n            the size of input (in case of MNIST 28 * 28).\n        hidden_dim: int\n            the size of hidden dimension.\n        experts: list of VAE\n            the manager experts.\n        latent_dim: int\n            the experts latent dimension (used to sample new images).\n        log_alpha: optional\n            currently unused.\n        \"\"\"\n        super(Manager, self).__init__()\n        self._experts = nn.ModuleList(experts)\n        self.latent_dim = latent_dim\n        self._experts_results = []\n        self.linear1 = nn.Linear(input_dim, hidden_dim)\n        self.linear2 = nn.Linear(hidden_dim, len(experts))\n\n    def forward(self, x):\n        hidden = torch.sigmoid(self.linear1(x))\n        logits = self.linear2(hidden)\n        # one probability per expert for each sample\n        probs = func.softmax(logits, dim=1)\n        self._experts_results = []\n        for net in self._experts:\n            self._experts_results.append(net(x))\n        return probs, {\"experts_results\": self._experts_results}\n\n\nclass ManagerLoss(object):\n    \"\"\" The MoE loss: expected expert loss + balancing term.\n    \"\"\"\n    def __init__(self, balancing_weight=0.1):\n        \"\"\" Init class.\n\n        Parameters\n        ----------\n        balancing_weight: float, default 0.1\n            how much the balancing term will contribute to the loss.\n        \"\"\"\n        super(ManagerLoss, self).__init__()\n        self.layer_outputs = None\n        self.balancing_weight = balancing_weight\n        self.criterion = VAELoss(use_distributions=False)\n\n    def __call__(self, probs, x):\n        if self.layer_outputs is None:\n            raise ValueError(\"The model needs to return the experts \"\n                             \"results.\")\n        # per-sample VAE loss of each expert: shape (batch, n_experts)\n        losses = []\n        for result in self.layer_outputs[\"experts_results\"]:\n            self.criterion.layer_outputs = result[1]\n            loss, extra_loss = self.criterion(result[0], x)\n            losses.append(loss.view(-1, 1))\n        losses = torch.cat(losses, dim=1)\n        # experts losses weighted by the manager probabilities\n        expected_expert_loss = torch.mean(\n            torch.sum(losses * probs, dim=1), dim=0)\n        # balancing term: variance, across experts, of the probability mass\n        # each expert receives over the batch\n        experts_importance = torch.sum(probs, dim=0)\n        # Remove effect of Bessel correction\n        experts_importance_std = experts_importance.std(dim=0, unbiased=False)\n        balancing_loss = torch.pow(experts_importance_std, 2)\n        combined_loss = (\n            expected_expert_loss + self.balancing_weight * balancing_loss)\n\n        return combined_loss, {\"expected_expert_loss\": expected_expert_loss,\n                               \"balancing_loss\": balancing_loss}"
      ]
    },
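    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The balancing term can be checked in isolation. The NumPy function below is an illustrative re-implementation of the computation done in `ManagerLoss`: the importance of each expert is the sum of its probabilities over the batch, and the penalty is the population variance of these importances (the square of `std(unbiased=False)`). A manager that spreads its mass evenly gets a penalty of (numerically) zero, while a degenerate manager that outputs the same one-hot vector for every input is heavily penalized. The batch size and probabilities are hypothetical.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef balancing_loss(probs):\n    \"\"\" Variance of the per-expert importance, as in ManagerLoss.\n\n    probs: array of shape (batch_size, n_experts), rows summing to 1.\n    \"\"\"\n    importance = probs.sum(axis=0)  # one value per expert\n    return importance.var()  # population variance (ddof=0)\n\n\nbatch_size, n_experts = 100, 10\n\n# balanced manager: every expert gets 10% of the mass\nbalanced = np.full((batch_size, n_experts), 1.0 / n_experts)\n\n# degenerate manager: the same one-hot vector whatever the input\ncollapsed = np.zeros((batch_size, n_experts))\ncollapsed[:, 0] = 1.0\n\nprint(balancing_loss(balanced))   # close to 0\nprint(balancing_loss(collapsed))  # 900.0"
      ]
    },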
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
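        "Training\n--------\n\nWe'll train the model with the Adam optimizer (learning rate 0.001) to\nminimize the manager loss defined above. After each epoch, the `sampling`\ncallback draws a single latent vector z ~ N(0, I) and decodes it with every\nexpert, so that the board displays side by side what each of the ten VAEs\nhas learned to generate.\n\n"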
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def sampling(signal):\n    \"\"\" Sample from the latent space and generate an image with each expert.\n    \"\"\"\n    device = signal.object.device\n    experts = signal.object.model._experts\n    latent_dim = signal.object.model.latent_dim\n    board = signal.object.board\n    # sample a single latent vector, shared by all the experts\n    z = torch.randn(1, latent_dim).to(device)\n    # run only the decoder of each expert\n    images = []\n    for model in experts:\n        model.eval()\n        with torch.no_grad():\n            reconstructed_img = model.decoder(z)\n            img = reconstructed_img.view(-1, 28, 28).cpu().numpy()\n            # upsample (nearest neighbour) for display\n            img = np.asarray([ndimage.zoom(arr, 5, order=0) for arr in img])\n            images.append(img)\n        # switch back to training mode\n        model.train()\n    # display result\n    images = np.asarray(images)\n    images = (images / images.max()) * 255\n    board.viewer.images(\n        images,\n        opts={\n            \"title\": \"sampling\",\n            \"caption\": \"sampling\"},\n        win=\"sampling\")\n\n\nlatent_dim = 20\nexperts = [\n    VAE(input_dim=(28 * 28), hidden_dim=128, latent_dim=latent_dim)\n    for _ in range(10)]\nmodel = Manager(input_dim=(28 * 28), hidden_dim=128, experts=experts,\n                latent_dim=latent_dim)\ninterface = DeepLearningInterface(\n    model=model,\n    optimizer_name=\"Adam\",\n    learning_rate=0.001,\n    loss=ManagerLoss(balancing_weight=0.1),\n    use_cuda=torch.cuda.is_available())\ninterface.board = Board(\n    port=8097, host=\"http://localhost\", env=\"vae\")\ninterface.add_observer(\"after_epoch\", update_board)\ninterface.add_observer(\"after_epoch\", sampling)\ntest_history, train_history = interface.training(\n    manager=manager,\n    nb_epochs=100,\n    checkpointdir=None,\n    fold_index=0,\n    with_validation=False)"
      ]
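    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, the claim that the manager can classify digits can be checked once training is done. One simple protocol, sketched here as our own assumption rather than something provided by pynet, routes each image to the expert with the highest manager probability, maps each expert to the majority digit among the images routed to it, and reports the accuracy of that mapping. The example below runs on made-up routing decisions; with the trained model, `probs` would be the first output of `model(x)` on a batch of flattened MNIST images and `labels` the corresponding digits.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef cluster_accuracy(probs, labels, n_classes=10):\n    \"\"\" Map each expert to its majority label and score the mapping.\n\n    probs: (n_samples, n_experts) manager outputs.\n    labels: (n_samples, ) integer ground-truth labels.\n    \"\"\"\n    assignments = probs.argmax(axis=1)  # expert chosen for each sample\n    expert_to_label = {}\n    for expert in np.unique(assignments):\n        routed = labels[assignments == expert]\n        # majority vote among the samples routed to this expert\n        counts = np.bincount(routed, minlength=n_classes)\n        expert_to_label[int(expert)] = int(counts.argmax())\n    predicted = np.array([expert_to_label[int(e)] for e in assignments])\n    return (predicted == labels).mean(), expert_to_label\n\n\n# hypothetical routing: 3 experts, 6 samples, one routing mistake\nprobs = np.array([\n    [0.8, 0.1, 0.1],\n    [0.7, 0.2, 0.1],\n    [0.1, 0.8, 0.1],\n    [0.2, 0.7, 0.1],\n    [0.1, 0.1, 0.8],\n    [0.6, 0.3, 0.1]])\nlabels = np.array([4, 4, 7, 7, 1, 1])\naccuracy, mapping = cluster_accuracy(probs, labels)\nprint(accuracy, mapping)"
      ]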
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}