{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nMulti Channels VAE (MCVAE)\n==========================\n\nCredit: A Grigis & C. Ambroise\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Imports\nimport os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport time\nimport copy\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom sklearn.preprocessing import StandardScaler\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import Dataset\nimport pynet\nfrom pynet import NetParameters\nfrom pynet.datasets import DataManager\nfrom pynet.datasets.core import DataItem\nfrom pynet.interfaces import MCVAEEncoder\nfrom pynet.utils import setup_logging\n\n\n# Global parameters\nn_samples = 500\nn_channels = 3\nn_feats = 4\ntrue_lat_dims = 2\nfit_lat_dims = 5\nsnr = 10\nadam_lr = 2e-3\nepochs = 5000\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nlosses = pynet.get_tools(tool_name=\"losses\")\nsetup_logging(level=\"info\")\n\n\n# Create synthetic data\n\n\nclass GeneratorUniform(nn.Module):\n    \"\"\" Generate multiple sources (channels) of data through a linear\n    generative model:\n    z ~ N(0,I)\n    for c_idx in n_channels:\n        x_ch = W_ch(c_idx)\n    where 'W_ch' is an arbitrary linear mapping z -> x_ch\n    \"\"\"\n    def __init__(self, lat_dim=2, n_channels=2, n_feats=5, seed=100):\n        super(GeneratorUniform, self).__init__()\n        self.lat_dim = lat_dim\n        self.n_channels = n_channels\n        self.n_feats = n_feats\n        self.seed = seed\n        np.random.seed(self.seed)\n\n        W = []\n        for c_idx in range(n_channels):\n            w_ = np.random.uniform(-1, 1, (self.n_feats, lat_dim))\n            u, s, vt = np.linalg.svd(w_, full_matrices=False)\n            w = (u if self.n_feats >= lat_dim else vt)\n            W.append(torch.nn.Linear(lat_dim, self.n_feats, bias=False))\n            W[c_idx].weight.data = torch.FloatTensor(w)\n\n        self.W = torch.nn.ModuleList(W)\n\n    def forward(self, z):\n        if isinstance(z, list):\n            return [self.forward(_) for _ in z]\n        if type(z) == np.ndarray:\n            z = torch.FloatTensor(z)\n        assert z.size(1) == self.lat_dim\n        obs = []\n        for ch in range(self.n_channels):\n            x = self.W[ch](z)\n            obs.append(x.detach())\n        return obs\n\n\nclass SyntheticDataset(Dataset):\n    def __init__(self, n_samples=500, lat_dim=2, n_feats=5, n_channels=2,\n                 generatorclass=GeneratorUniform, snr=1, train=True):\n        super(SyntheticDataset, self).__init__()\n        self.n_samples = n_samples\n        self.lat_dim = lat_dim\n        self.n_feats = n_feats\n        self.n_channels = n_channels\n        self.snr = snr\n        self.train = train\n        seed = (7 if self.train is True else 14)\n        np.random.seed(seed)\n        self.z = np.random.normal(size=(self.n_samples, self.lat_dim))\n        self.generator = generatorclass(\n            lat_dim=self.lat_dim, n_channels=self.n_channels,\n            n_feats=self.n_feats)\n        self.x = self.generator(self.z)\n        self.X, self.X_noisy = preprocess_and_add_noise(self.x, snr=snr)\n        self.X = [np.expand_dims(x.astype(np.float32), axis=1) for x in self.X]\n\n    def __len__(self):\n        return self.n_samples\n\n    def __getitem__(self, item):\n        return DataItem(inputs=[x[item] for x in self.X], outputs=None,\n                        labels=None)\n\n    @property\n    def shape(self):\n        return (len(self), len(self.X))\n\ndef preprocess_and_add_noise(x, snr, seed=0):\n    if not isinstance(snr, list):\n        snr = [snr] * len(x)\n    scalers = [StandardScaler().fit(c_arr) for c_arr in x]\n    x_std = [scalers[c_idx].transform(x[c_idx]) for c_idx in range(len(x))]\n    # seed for reproducibility in training/testing based on prime number basis\n    seed = (seed + 3 * int(snr[0] + 1) + 5 * len(x) + 7 * x[0].shape[0] +\n            11 * x[0].shape[1])\n    np.random.seed(seed)\n    x_std_noisy = []\n    for c_idx, arr in enumerate(x_std):\n        sigma_noise = np.sqrt(1. / snr[c_idx])\n        x_std_noisy.append(arr + sigma_noise * np.random.randn(*arr.shape))\n    return x_std, x_std_noisy\n\n\n# Create dataset\nds_train = SyntheticDataset(\n    n_samples=n_samples,\n    lat_dim=true_lat_dims,\n    n_feats=n_feats,\n    n_channels=n_channels,\n    train=True,\n    snr=snr)\nds_val = SyntheticDataset(\n    n_samples=n_samples,\n    lat_dim=true_lat_dims,\n    n_feats=n_feats,\n    n_channels=n_channels,\n    train=False,\n    snr=snr)\nimage_datasets = {\n    \"train\": ds_train,\n    \"val\": ds_val}\nmanager = DataManager.from_dataset(\n    train_dataset=image_datasets[\"train\"],\n    validation_dataset=image_datasets[\"val\"],\n    batch_size=n_samples, sampler=\"random\", multi_bloc=True)\nprint(\"- datasets:\", image_datasets)\n\n\n# Create models\nmodels = {}\ntorch.manual_seed(42)\nparams = NetParameters(\n    latent_dim=fit_lat_dims,\n    n_channels=n_channels,\n    n_feats=[n_feats] * n_channels,\n    vae_model=\"dense\",\n    vae_kwargs={},\n    sparse=False)\nmodels[\"mcvae\"] = MCVAEEncoder(params,\n    optimizer_name=\"Adam\",\n    learning_rate=adam_lr,\n    loss=losses[\"MCVAELoss\"](n_channels, beta=1., sparse=False),\n    use_cuda=False)\ntorch.manual_seed(42)\nparams = NetParameters(\n    latent_dim=fit_lat_dims,\n    n_channels=n_channels,\n    n_feats=[n_feats] * n_channels,\n    vae_model=\"dense\",\n    vae_kwargs={},\n    sparse=True)\nmodels[\"smcvae\"] = MCVAEEncoder(params,\n    optimizer_name=\"Adam\",\n    learning_rate=adam_lr,\n    loss=losses[\"MCVAELoss\"](n_channels, beta=1., sparse=True),\n    use_cuda=False)\nprint(\"- models:\", models)\n\n\n# Fit models\nfor model_name, interface in models.items():\n    print(\"- training:\", model_name)\n    train_history, valid_history = interface.training(\n        manager=manager,\n        nb_epochs=epochs,\n        checkpointdir=None,\n        fold_index=0,\n        with_validation=True)\n\n\n# Display results\npred = {}  # Prediction\nz = {}     # Latent Space\ng = {}     # Generative Parameters\nx_hat = {}  # Reconstructed channels\nloaders = manager.get_dataloader(validation=True, fold_index=0)\ndataitem = next(iter(loaders.validation))\n\nfor model_name, interface in models.items():\n    model = interface.model\n    model.eval()\n    X = [x.to(interface.device) for x in dataitem.inputs]\n    print(\"--\", model_name)\n    print(\"-- X\", [x.size() for x in X])\n\n    with torch.no_grad():\n        q = model.encode(X)  # encoded distribution q(z|x)\n    print(\"-- encoded distribution q(z|x)\", [n for n in q])\n\n    z[model_name] = model.p_to_prediction(q)\n    print(\"-- z\", [e.shape for e in z[model_name]])\n\n    if model.sparse:\n        z[model_name] = model.apply_threshold(z[model_name], 0.2)\n    z[model_name] = np.array(z[model_name]).reshape(-1) # flatten\n    print(\"-- z\", z[model_name].shape)\n\n    g[model_name] = [\n        model.vae[c_idx].encode.w_mu.weight.detach().numpy()\n        for c_idx in range(n_channels)]\n    g[model_name] = np.array(g[model_name]).reshape(-1)  #flatten\n\n\n# With such a simple dataset, mcvae and sparse-mcvae gives the same results in\n# terms of latent space and generative parameters.\n# However, only with the sparse model is possible to easily identify the\n# important latent dimensions.\n\nplt.figure()\nplt.subplot(1,2,1)\nplt.hist([z[\"smcvae\"], z[\"mcvae\"]], bins=20, color=[\"k\", \"gray\"])\nplt.legend([\"Sparse\", \"Non sparse\"])\nplt.title(\"Latent dimensions distribution\")\nplt.ylabel(\"Count\")\nplt.xlabel(\"Value\")\nplt.subplot(1,2,2)\nplt.hist([g[\"smcvae\"], g[\"mcvae\"]], bins=20, color=[\"k\", \"gray\"])\nplt.legend([\"Sparse\", \"Non sparse\"])\nplt.title(r\"Generative parameters $\\mathbf{\\theta} = \\{\\mathbf{\\theta}_1 \"\n          r\"\\ldots \\mathbf{\\theta}_C\\}$\")\nplt.xlabel(\"Value\")\n\ndo = np.sort(models[\"smcvae\"].model.dropout.detach().numpy().reshape(-1))\nplt.figure()\nplt.bar(range(len(do)), do)\nplt.suptitle(\"Dropout probability of {0} fitted latent dimensions in Sparse \"\n             \"Model\".format(fit_lat_dims))\nplt.title(\"{0} true latent dimensions\".format(true_lat_dims))\n\nplt.show()\nprint(\"See you!\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}