{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nPathway factorized latent space\n===============================\n\nCredit: A Grigis\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Imports\nimport os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\nimport shutil\nimport subprocess\nfrom itertools import product\nimport anndata\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom matplotlib.colors import rgb2hex\nfrom matplotlib.patches import Patch\nfrom sklearn.manifold import TSNE\nfrom umap import UMAP\nimport torch\nimport pynet\nfrom pynet import NetParameters\nfrom pynet.datasets import DataManager, fetch_kang\nfrom pynet.interfaces import PMVAEEncoder\nfrom pynet.utils import setup_logging"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Parameters\n----------\n\nDefine some global parameters that will be used to create and train the\nmodel:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "datasetdir = \"/neurospin/nsap/datasets/kang\"\nbatch_size = 256\nlatent_dim = 4\nnb_epochs = 1201\nlearning_rate = 0.001\nbeta = 1e-5\ncheckpointdir = \"/neurospin/nsap/datasets/kang/checkpoints\"\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nlosses = pynet.get_tools(tool_name=\"losses\")\nsetup_logging(level=\"info\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Kang dataset\n------------\n\nFetch & load the Kang dataset:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "data, trainset, testset, membership_mask = fetch_kang(\n    datasetdir=datasetdir, random_state=0)\ngtpath = os.path.join(datasetdir, \"kang_recons.h5ad\")\nmanager = DataManager.from_numpy(\n    train_inputs=trainset, validation_inputs=testset, test_inputs=data.X,\n    batch_size=batch_size, sampler=\"random\", add_input=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training\n--------\n\nCreate/train the model:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if checkpointdir is not None:\n    weights_filename = os.path.join(\n        checkpointdir, \"model_0_epoch_{0}.pth\".format(nb_epochs - 1))\nparams = NetParameters(\n    membership_mask=membership_mask,\n    latent_dim=latent_dim,\n    hidden_layers=[12],\n    add_auxiliary_module=True,\n    terms=membership_mask.index,\n    activation=None)\nif checkpointdir is not None and os.path.isfile(weights_filename):\n    model = PMVAEEncoder(\n        params,\n        optimizer_name=\"Adam\",\n        learning_rate=learning_rate,\n        loss=losses[\"PMVAELoss\"](beta=beta),\n        use_cuda=(device.type != \"cpu\"),\n        pretrained=weights_filename)\n    print(model.model)\nelse:\n    model = PMVAEEncoder(\n        params,\n        optimizer_name=\"Adam\",\n        learning_rate=learning_rate,\n        loss=losses[\"PMVAELoss\"](beta=beta),\n        use_cuda=(device.type != \"cpu\"))\n    print(model.model)\n    train_history, valid_history = model.training(\n        manager=manager,\n        nb_epochs=nb_epochs,\n        checkpointdir=checkpointdir,\n        save_after_epochs=100,\n        fold_index=0,\n        with_validation=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reduce the number of dimensions\n-------------------------------\n\nUse TSNE to create a 2d representation of the results:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def extract_pathway_cols(df, pathway):\n    mask = df.columns.str.startswith(pathway + \"-\")\n    return df.loc[:, mask]\n\n\ndef compute_reduction(recons, pathways, reduction=\"tsne\"):\n    if reduction not in (\"tsne\", \"umap\"):\n        raise ValueError(\"Unexpected reduction type.\")\n    for key in pathways:\n        if reduction == \"tsne\":\n            reducer = TSNE(n_components=2)\n        else:\n            reducer = UMAP(n_components=2)\n        codes = extract_pathway_cols(recons.obsm[\"codes\"], key)\n        embedding = pd.DataFrame(\n            reducer.fit_transform(codes.values),\n            index=recons.obs_names,\n            columns=[\"{0}-0\".format(key), \"{0}-1\".format(key)])\n        yield embedding\n\n\noutput_file = os.path.join(checkpointdir, \"kang_recons.h5ad\")\ngenerated_pathways = [\n    \"REACTOME_INTERFERON_ALPHA_BETA_SIGNALING\",\n    \"REACTOME_CYTOKINE_SIGNALING_IN_IMMUNE_SYSTEM\",\n    \"REACTOME_TCR_SIGNALING\",\n    \"REACTOME_CELL_CYCLE\"]\nif not os.path.isfile(output_file):\n    y, X, _, loss, values = model.testing(\n        manager=manager,\n        with_logit=False,\n        predict=False,\n        concat_layer_outputs=\"z\")\n    print(y.shape)\n    global_recon = y[:, :membership_mask.shape[1]]\n    z = y[:, membership_mask.shape[1]:]\n    print(\" -- global recon:\", global_recon.shape)\n    print(\" -- z:\", z.shape)\n    recons = anndata.AnnData(\n        pd.DataFrame(\n            global_recon,\n            index=data.obs_names,\n            columns=data.var_names),\n        obs=data.obs,\n        varm=data.varm,\n    )\n    recons.obsm[\"codes\"] = pd.DataFrame(\n        z,\n        index=data.obs_names,\n        columns=model.model.latent_space_names())\n    recons.obsm[\"pathway_tsnes\"] = pd.concat(\n        compute_reduction(recons, generated_pathways, reduction=\"tsne\"),\n        axis=1)\n    recons.obsm[\"pathway_umaps\"] = pd.concat(\n        compute_reduction(recons, generated_pathways, reduction=\"umap\"),\n        axis=1)\n    recons.write(output_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Display\n--------\n\nDisplay the results & the ground truth:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def extract_pathway_cols(df, pathway):\n    mask = df.columns.str.startswith(pathway + \"-\")\n    return df.loc[:, mask]\n\n\ndef tab20(arg):\n    cmap = plt.get_cmap(\"tab20\")\n    return rgb2hex(cmap(arg))\n\n\ngenerated_recons = anndata.read(output_file)\nrecons = anndata.read(gtpath)\ncmap = {\n    \"CD4 T\": tab20(0),\n    \"CD8 T\": tab20(1),\n    \"CD14 Mono\": tab20(2),\n    \"CD16 Mono\": tab20(3),\n    \"B\": tab20(4),\n    \"DC\": tab20(6),\n    \"NK\": tab20(8),\n    \"T\": tab20(10)}\npathways = [\n    \"INTERFERON_ALPHA_BETA_SIGNALIN\",\n    \"CYTOKINE_SIGNALING_IN_IMMUNE_S\",\n    \"TCR_SIGNALING\",\n    \"CELL_CYCLE\"]\nfor _name, _reduction, _recons, _pathways in (\n        (\"GT\", \"tsne\", recons, pathways),\n        (\"GENERATED\", \"tsne\", generated_recons, generated_pathways),\n        (\"GENERATED\", \"umap\", generated_recons, generated_pathways)):\n    fig, axes = plt.subplots(2, len(pathways), figsize=(6 * len(_pathways), 8))\n    title = \"{0} pathway factorized latent space results ({1})\".format(\n        _name, _reduction.upper())\n    fig.suptitle(title, fontsize=15, y=0.99)\n    pairs = product([\"stimulated\", \"control\"], _pathways)\n    for ax, (active, key) in zip(axes.ravel(), pairs):\n        mask = (_recons.obs[\"condition\"] == active)\n        codes = extract_pathway_cols(_recons.obsm[\"pathway_tsnes\"], key)\n        # plot non-active condition\n        ax.scatter(*codes.loc[~mask].T.values, s=1, c=\"lightgrey\", alpha=0.1) \n        # plot active condition\n        ax.scatter(*codes.loc[mask].T.values,\n                   c=list(map(cmap.get, _recons.obs.loc[mask, \"cell_type\"])),\n                   s=1, alpha=0.5,)\n        key = key.replace(\"REACTOME_\", \"\")[:30]\n        ax.set_title(\"{0} {1}\".format(key, active), fontsize=10)\n        ax.axis(\"off\")\n    fig.legend(\n        handles=[Patch(color=c, label=l) for l,c in cmap.items()],\n        ncol=4, loc=(\"lower center\"), bbox_to_anchor=(0.5, 0.01),\n        fontsize=\"xx-large\", prop={\"size\": 10})\n    plt.tight_layout()\n    fig.subplots_adjust(bottom=.1)\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}