{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nMoE-Sim-VAE\n===========\n\nCredit: A Grigis\n\nMixture-of-Experts Variational Autoencoder (VAE) with a similarity-based\nprior: MoE-Sim-VAE.\n\nReference: Mixture-of-Experts Variational Autoencoder for Clustering and\nGenerating from Similarity-Based Representations on Single Cell Data,\nAndreas Kopf et al., arXiv 2020.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Imports\n# Standard library\nimport os\nimport sys\n\n# Stop here when the CI_MODE environment variable is set (e.g. in continuous\n# integration), before the heavier third-party packages below are imported.\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\n# Third-party packages\nimport numpy as np\nimport matplotlib.colors\nfrom matplotlib.lines import Line2D\nimport matplotlib.pyplot as plt\nfrom sklearn.neighbors import NearestNeighbors\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import log_loss\nimport umap\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as func\nfrom torch.distributions import Normal, kl_divergence\n\n# Project (pynet)\nimport pynet\nfrom pynet import NetParameters\nfrom pynet.datasets import DataManager, fetch_minst\nfrom pynet.interfaces import MOESimVAENetEncoder\nfrom pynet.plotting import Board, update_board"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Parameters\n----------\n\nDefine some global parameters that will be used to create and train the\nmodel:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Reproducibility: seed NumPy and PyTorch; the same seed is also passed\n# to the UMAP-based similarity computation below.\nrandom_state = 42\nnp.random.seed(random_state)\ntorch.manual_seed(random_state)\n\n# Data and checkpoints\ndatasetdir = \"/neurospin/nsap/datasets/minst\"\ncheckpointdir = None  # optional directory for training checkpoints\n\n# Similarity prior: UMAP embedding dimension and k of the k-NN graph\nn_components_umap = 2\nn_neighbors_knn = 10\n\n# Optimization\nbatch_size = 128\nn_epochs = 10  # short demo run; increase (e.g. to 20000) for full training\nlearning_rate = 0.0001\n\n# Architecture and loss hyper-parameters\ninput_dim = 28 * 28  # flattened MNIST image\ndropout_rate = 0.5\nlatent_dim = 68\nn_experts = 10\nbeta = 1.\nalpha = 1.\n\n# Training runs on the CPU; use\n# torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n# instead to select a GPU when one is available.\ndevice = torch.device(\"cpu\")\nlosses = pynet.get_tools(tool_name=\"losses\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "MNIST dataset\n-------------\n\nThe model will be trained on the MNIST handwritten digits dataset. Each\ninput is an image in $\\mathbb{R}^{28 \\times 28}$, flattened into a vector of\nlength $28 \\times 28 = 784$:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def flatten(arr):\n    \"\"\"Return a flattened (1-D) copy of `arr` (assumed to be a NumPy array).\"\"\"\n    return arr.flatten()\n\n# Fetch MNIST and wrap it in a DataManager. The fetch result gets its own\n# name so that it is not confused with the `data` batch array used later.\nminst_dataset = fetch_minst(datasetdir=datasetdir)\nmanager = DataManager(\n    input_path=minst_dataset.input_path,\n    metadata_path=minst_dataset.metadata_path,\n    stratify_label=\"label\",\n    labels=\"label\",\n    number_of_folds=10,\n    batch_size=batch_size,\n    test_size=0,\n    input_transforms=[flatten],  # presumably applied per image (-> input_dim)\n    add_input=True,\n    sample_size=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Data-driven similarity matrix\n-----------------------------\n\nThe similarity matrix is derived in an unsupervised way (e.g., a UMAP\nprojection of the data followed by k-nearest neighbors or distance\nthresholding to define the adjacency matrix of the batch), but it can also\nencode weakly supervised information (e.g., knowledge about diseased vs.\nnon-diseased patients). The similarity feature of MoE-Sim-VAE can also be\nused to include prior knowledge about the best similarity measure for the\ndata. Here, a UMAP embedding and a k-nearest-neighbors graph are used.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Similarity prior computed on a single batch of MNIST images.\ndata = manager.inputs[:batch_size]\nlabels = manager.labels[:batch_size]\nsimilarity, embedding = losses[\"MOESimVAELoss\"].get_similarity_matrix(\n    data, n_components_umap, n_neighbors_knn, random_state=random_state)\nprint(\"-- umap embedding:\", embedding.shape)\nprint(\"-- similarity:\", similarity.shape)\n\n# Up to 100 images of the batch on a 10x10 grid; zip() stops early when the\n# batch is smaller, leaving the remaining panels empty instead of failing.\nfig, ax_array = plt.subplots(10, 10)\naxes = ax_array.flatten()\nfor ax, image in zip(axes, data):\n    ax.imshow(image[0], cmap=\"gray_r\")  # assumes (batch, channel, H, W) inputs\nplt.setp(axes, xticks=[], yticks=[], frame_on=False)\nfig.tight_layout(h_pad=0.5, w_pad=0.01)\n\n# UMAP embedding of the batch (first two components), colored by label.\nfig, ax = plt.subplots()\npoints = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap=\"Spectral\",\n                    s=5)\nax.set_aspect(\"equal\", \"datalim\")\nfig.colorbar(points, ax=ax, boundaries=(np.arange(11) - 0.5)).set_ticks(\n    np.arange(10))\nax.axis(\"off\")\nax.set_title(\"UMAP projection of the dataset\", fontsize=10)\n\n# Similarity (k-NN adjacency) matrix, drawn with a two-color map (0/1).\nfig, ax = plt.subplots()\ncmap = matplotlib.colors.ListedColormap([\"white\", \"orange\"])\nadjacency = ax.imshow(similarity, cmap=cmap)\nax.axis(\"off\")\nax.set_title(\"K-nearest-neighbors graph\", fontsize=10)\nfig.colorbar(adjacency, ax=ax, boundaries=(np.arange(3) - 0.5)).set_ticks(\n    np.arange(2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
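        "Similarity loss\n---------------\n\nThe similarity loss is a data-driven clustering loss. As background, we\nfirst reproduce the binary cross-entropy (log loss) example from\n[Understanding binary cross-entropy / log loss: a visual explanation](https://towardsdatascience.com/understanding-binary-cross-entropy-log-loss-a-visual-explanation-a3ac6025181a),\nthen compare the similarity loss with a standard cross-entropy on a toy\nclustering of four samples into two clusters.\n\n"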
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Part 1 - binary cross-entropy (log loss) on a 1-D toy problem.\nx = np.array([-2.2, -1.4, -.8, .2, .4, .8, 1.2, 2.2, 2.9, 4.6])\ny = np.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n\ncustom_lines = [Line2D([0], [0], color=\"red\", lw=4),\n                Line2D([0], [0], color=\"green\", lw=4),\n                Line2D([0], [0], color=\"blue\", lw=4)]\n\nlogr = LogisticRegression(solver=\"lbfgs\")\nlogr.fit(x.reshape(-1, 1), y)\ny_pred = logr.predict_proba(x.reshape(-1, 1))[:, 1].ravel()\n# Probability assigned to the true class of each sample.\nprob = y_pred.copy()\nprob[y == 0.] = 1 - prob[y == 0.]\nloss = log_loss(y, y_pred)\nprint(\"x = {}\".format(x))\nprint(\"y = {}\".format(y))\nprint(\"p(y) = {}\".format(np.round(y_pred, 2)))\nprint(\"Log Loss / Cross Entropy = {:.4f}\".format(loss))\n\nfig, ax = plt.subplots()\ncolors = [\"red\" if yi == 0. else \"green\" for yi in y]\n# Above y = 0: per-sample -log(p) bars and their mean, the log loss (dashed).\nax.bar(x, -np.log(prob), width=0.1, color=colors, alpha=0.5)\nax.axhline(y=loss, color=\"black\", linestyle=\"--\")\n# Below y = 0: samples, the fitted sigmoid (shifted down by 1.1) and p bars.\nax.scatter(x, [-0.05 if yi == 0. else -1.15 for yi in y], color=colors,\n           edgecolors=\"black\", s=40, marker=\"o\", alpha=0.5)\nax.plot(x, y_pred - 1.1, color=\"blue\")\nax.bar(x[y == 1.], y_pred[y == 1.], width=0.1, bottom=-1.1, color=\"green\",\n       alpha=0.5)\nax.bar(x[y == 0.], 1 - y_pred[y == 0.], width=0.1,\n       bottom=-(1. - y_pred[y == 0.] + 0.1), color=\"red\", alpha=0.5)\nax.text(0.5, 0.5, \"{:.4f}\".format(loss))\nax.set_title(\"Binary Cross Entropy\", fontsize=10)\nax.text(-0.45, 0.75, \"-log(p)\")\nax.text(-0.2, -0.3, \"p\")\nax.spines[\"left\"].set_position(\"zero\")\nax.spines[\"right\"].set_color(\"none\")\nax.yaxis.tick_left()\nax.spines[\"bottom\"].set_position(\"zero\")\nax.spines[\"top\"].set_color(\"none\")\nax.xaxis.tick_bottom()\nax.grid(True, which=\"both\")\nax.legend(custom_lines, [\"Negative\", \"Positive\", \"Sigmoid\"])\n\n\n# Part 2 - similarity loss vs. cross-entropy on a toy clustering of four\n# samples: samples 0-1 form cluster 0 and samples 2-3 form cluster 1.\ndef cross_entropy(predictions, targets, eps=1e-15):\n    \"\"\"Return -sum(targets * log(predictions)) / N, N being the row count.\n\n    Predictions are clipped to [eps, 1 - eps] so that exact zeros (reached\n    here when factor == 1) give a finite value instead of nan.\n    \"\"\"\n    predictions = np.clip(predictions, eps, 1. - eps)\n    n_rows = predictions.shape[0]\n    return -np.sum(targets * np.log(predictions)) / n_rows\n\n\nprobs_true = torch.zeros((4, 2))\nprobs_true[:2, 0] = 1\nprobs_true[2:, 1] = 1\n# Target similarity: 1 for pairs in the same cluster, 0 otherwise.\nsimilarity = torch.mm(probs_true, torch.transpose(probs_true, 0, 1))\nprint(similarity)\nfactors = np.linspace(0.1, 1, 10)\nsim_losses = []\nce_losses = []\nfor factor in factors:\n    # Soft assignments: `factor` on the true cluster, 1 - factor elsewhere.\n    probs = probs_true * factor\n    probs[:2, 1] = 1 - factor\n    probs[2:, 0] = 1 - factor\n    predictions = torch.mm(probs, torch.transpose(probs, 0, 1))\n    _pair_ce = cross_entropy(predictions.numpy(), similarity.numpy())\n    _loss = losses[\"MOESimVAELoss\"].similarity(probs, similarity)\n    _loss = torch.mean(torch.sum(_loss, dim=1), dim=0)\n    _sim_loss = _loss.cpu().numpy()\n    _ce_loss = log_loss(probs_true[:, 0], probs[:, 0])\n    if np.isnan(_ce_loss):\n        _ce_loss = 0.\n    # One summary line per factor instead of dumping the full tensors.\n    print(\"factor={:.1f}  pairwise CE={:.4f}  SIM={}  CE={:.4f}\".format(\n        factor, _pair_ce, np.round(_sim_loss, 4), _ce_loss))\n    sim_losses.append(_sim_loss)\n    ce_losses.append(_ce_loss)\nfig, ax = plt.subplots()\nax.plot(factors, sim_losses, color=\"blue\", label=\"SIM\")\nax.plot(factors, ce_losses, color=\"green\", label=\"CE\")\nax.set_title(\"SIMILARITY losses\", fontsize=10)\nax.set_xlabel(\"factors\")\nax.grid(True, which=\"both\")\nax.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "DEPICT loss\n-----------\n\nThe DEPICT loss encourages the model to learn latent features for\nclustering that are invariant to noise. Below, it is evaluated between a\nreference assignment vector and scaled copies of it.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Reference assignments (all ones) vs. copies scaled by `factor`.\nprobs = torch.ones((1, 10))\nfactors = np.linspace(0.1, 1, 10)\ndepict_losses = []\nfor factor in factors:\n    probs_noisy = torch.ones((1, 10)) * factor\n    _loss = losses[\"MOESimVAELoss\"].depict(probs, probs_noisy).mean()\n    depict_losses.append(_loss.cpu().numpy())\nfig, ax = plt.subplots()\nax.plot(factors, depict_losses)\nax.set_title(\"DEPICT losses\", fontsize=10)\nax.set_xlabel(\"factors\")\nax.set_ylabel(\"loss\")\nax.grid(True, which=\"both\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Model\n---------\n\nThe model is a VAE with a Gaussian Mixture Prior (GMP) and N independent\ndecoder paths, one per expert (mixture component):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Network hyper-parameters; one mixture component per expert.\nparams = NetParameters(\n    input_dim=input_dim,\n    latent_dim=latent_dim,\n    n_mix_components=n_experts,\n    dense_hidden_dims=[256],\n    classifier_hidden_dims=[100],\n    sigma_min=0.001,\n    raw_sigma_bias=0.25,\n    gen_bias_init=0,\n    dropout=dropout_rate)  # configured rate, not a hard-coded copy of it\ninterface = MOESimVAENetEncoder(\n    params,\n    optimizer_name=\"Adam\",\n    learning_rate=learning_rate,\n    loss=losses[\"MOESimVAELoss\"](\n        beta=beta, alpha=alpha, n_components_umap=n_components_umap,\n        n_neighbors_knn=n_neighbors_knn),\n    use_cuda=(device.type != \"cpu\"))\nprint(interface.model)\n\n# Live monitoring: update_board is registered to run after every epoch.\n# NOTE(review): the Board presumably needs a dashboard server reachable at\n# http://localhost:8097 - TODO confirm.\ninterface.board = Board(\n    port=8097, host=\"http://localhost\", env=\"moevae\")\ninterface.add_observer(\"after_epoch\", update_board)\n\ntrain_history, valid_history = interface.training(\n    manager=manager,\n    nb_epochs=n_epochs,\n    checkpointdir=checkpointdir,\n    save_after_epochs=100,\n    fold_index=0,\n    with_validation=False)\nprint(train_history)\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}