{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nUnsupervised clustering with GMVAE\n==================================\n\nCredit: A Grigis\n\nUnsupervised Gaussian Mixture Variational Autoencoder (GMVAE) on a synthetic\ndataset.\n\nGMVAE is an attempt to replicate the work described in this\n[blog](http://ruishu.io/2016/12/25/gmvae/) and inspired from this\n[paper](https://arxiv.org/abs/1611.02648).\n\nLet's begin with importing stuffs:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Imports\nimport os\nimport sys\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom matplotlib.ticker import NullFormatter\nfrom sklearn import manifold\nfrom sklearn.cluster import KMeans\nfrom sklearn.preprocessing import StandardScaler\nimport torch\nimport torch.nn as nn\nimport pynet\nfrom pynet import NetParameters\nfrom pynet.datasets import DataManager\nfrom pynet.interfaces import GMVAENetClassifier\nfrom pynet.utils import setup_logging"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Parameters\n----------\n\nDefine some global parameters that will be used to create and train the\nmodel:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_samples = 100\nn_classes = 3\nn_feats = 4\ntrue_lat_dims = 2\nfit_lat_dims = 5\nsnr = 10\nbatch_size = 10\nadam_lr = 2e-3\nepochs = 100\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nlosses = pynet.get_tools(tool_name=\"losses\")\nmetrics = pynet.get_tools(tool_name=\"metrics\")\nsetup_logging(level=\"info\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Synthetic dataset\n-----------------\n\nA Gaussian Linear Multi-Klass synthetic dataset is generated as\nfollows. The number of the latent dimensions used to generate the data can be\ncontrolled.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class GeneratorUniform(nn.Module):\n    \"\"\" Generate multiple sources (channels) of data through a linear\n    generative model:\n\n    z ~ N(mu,sigma)\n    for c_idx in n_channels:\n        x_ch = W_ch(c_idx)\n    where 'W_ch' is an arbitrary linear mapping z -> x_ch\n    \"\"\"\n    def __init__(self, lat_dim=2, n_channels=2, n_feats=5, seed=100):\n        super(GeneratorUniform, self).__init__()\n        self.lat_dim = lat_dim\n        self.n_channels = n_channels\n        self.n_feats = n_feats\n        self.seed = seed\n        np.random.seed(self.seed)\n        W = []\n        for c_idx in range(n_channels):\n            w_ = np.random.uniform(-1, 1, (self.n_feats, lat_dim))\n            u, s, vt = np.linalg.svd(w_, full_matrices=False)\n            w = (u if self.n_feats >= lat_dim else vt)\n            W.append(torch.nn.Linear(lat_dim, self.n_feats, bias=False))\n            W[c_idx].weight.data = torch.FloatTensor(w)\n        self.W = torch.nn.ModuleList(W)\n\n    def forward(self, z):\n        if isinstance(z, list):\n            return [self.forward(_) for _ in z]\n        if type(z) == np.ndarray:\n            z = torch.FloatTensor(z)\n        assert z.size(dim=1) == self.lat_dim\n        obs = []\n        for c_idx in range(self.n_channels):\n            x = self.W[c_idx](z)\n            obs.append(x.detach())\n        return obs\n\n\nclass SyntheticDataset(object):\n    def __init__(self, n_samples=500, lat_dim=2, n_feats=5, n_classes=2,\n                 generatorclass=GeneratorUniform, snr=1, train=True):\n        super(SyntheticDataset, self).__init__()\n        self.n_samples = n_samples\n        self.lat_dim = lat_dim\n        self.n_feats = n_feats\n        self.n_classes = n_classes\n        self.snr = snr\n        self.train = train\n        self.labels = []\n        self.z = []\n        self.x = []\n        seed = 7 if self.train else 14\n        np.random.seed(seed)\n        locs = np.random.uniform(-5, 5, (self.n_classes, ))\n        np.random.seed(seed)\n        scales = np.random.uniform(0, 2, (self.n_classes, ))\n        np.random.seed(seed)\n        for k_idx in range(self.n_classes):\n            self.z.append(\n                np.random.normal(loc=locs[k_idx], scale=scales[k_idx],\n                                 size=(self.n_samples, self.lat_dim)))\n            self.generator = generatorclass(\n                lat_dim=self.lat_dim, n_channels=1, n_feats=self.n_feats)\n            self.x.append(self.generator(self.z[-1])[0])\n            self.labels += [k_idx] * self.n_samples\n        self.data = np.concatenate(self.x, axis=0).astype(np.float32)\n        self.labels = np.asarray(self.labels)\n        _, self.data = preprocess_and_add_noise(self.data, snr=snr)\n\n\ndef preprocess_and_add_noise(x, snr, seed=0):\n    scalers = StandardScaler().fit(x)\n    x_std = scalers.transform(x)\n    np.random.seed(seed)\n    sigma_noise = np.sqrt(1. / snr)\n    x_std_noisy = x_std + sigma_noise * np.random.randn(*x_std.shape)\n    return x_std, x_std_noisy\n\n\n# Create dataset\nds_train = SyntheticDataset(\n    n_samples=n_samples,\n    lat_dim=true_lat_dims,\n    n_feats=n_feats,\n    n_classes=n_classes,\n    train=True,\n    snr=snr)\nds_val = SyntheticDataset(\n    n_samples=n_samples,\n    lat_dim=true_lat_dims,\n    n_feats=n_feats,\n    n_classes=n_classes,\n    train=False,\n    snr=snr)\nimage_datasets = {\n    \"train\": ds_train,\n    \"val\": ds_val}\nmanager = DataManager.from_numpy(\n    train_inputs=ds_train.data, train_outputs=None, train_labels=ds_train.labels,\n    validation_inputs=ds_val.data, validation_outputs=None,\n    validation_labels=ds_val.labels, batch_size=batch_size, sampler=\"random\",\n    add_input=True)\nprint(\"- datasets:\", image_datasets)\nprint(\"- shapes:\", ds_train.data.shape, ds_val.data.shape)\n\n\n# Display generated data\nmethod = manifold.TSNE(n_components=2, init=\"pca\", random_state=0)\ny_train = method.fit_transform(ds_train.data)\ny_val = method.fit_transform(ds_val.data)\nfig, axs = plt.subplots(nrows=3, ncols=2)\nfor cnt, (name, y, labels) in enumerate((\n        (\"train\", y_train, ds_train.labels),\n        (\"val\", y_val, ds_val.labels))):\n    colors = labels.astype(float)\n    colors /= colors.max()\n    axs[0, cnt].scatter(y[:, 0], y[:, 1], c=colors, cmap=plt.cm.Spectral)\n    axs[0, cnt].xaxis.set_major_formatter(NullFormatter())\n    axs[0, cnt].yaxis.set_major_formatter(NullFormatter())\n    axs[0, cnt].set_title(\"GT clustering ({0})\".format(name))\n    axs[0, cnt].axis(\"tight\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "ML clustering\n-------------\n\nAs a ground truth we performed a K-means clustering of the data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "kmeans = KMeans(n_clusters=n_classes, random_state=0).fit(ds_train.data)\ntrain_labels = kmeans.labels_\ntrain_acc = losses[\"GMVAELoss\"].cluster_acc(train_labels, ds_train.labels)\nprint(\"-- K-Means ACC train\", train_acc)\nval_labels = kmeans.predict(ds_val.data)\nval_acc = losses[\"GMVAELoss\"].cluster_acc(val_labels, ds_val.labels)\nprint(\"-- K-Means ACC val\",val_acc)\n\nfor cnt, (name, y, labels, acc) in enumerate((\n        (\"train\", y_train, train_labels, train_acc),\n        (\"val\", y_val, val_labels, val_acc))):\n    colors = labels.astype(float)\n    colors /= colors.max()\n    axs[1, cnt].scatter(y[:, 0], y[:, 1], c=colors, cmap=plt.cm.Spectral)\n    axs[1, cnt].xaxis.set_major_formatter(NullFormatter())\n    axs[1, cnt].yaxis.set_major_formatter(NullFormatter())\n    axs[1, cnt].set_title(\n        \"K-means clustering ({0}-ACC:{1:.3f})\".format(name, acc))\n    axs[1, cnt].axis(\"tight\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training\n--------\n\nWe'll create and train the model to optimize the losses using Adam\noptimizer.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "torch.manual_seed(42)\nparams = NetParameters(\n    input_dim=n_feats,\n    latent_dim=fit_lat_dims,\n    n_mix_components=n_classes,\n    sigma_min=0.001,\n    raw_sigma_bias=0.25,\n    dropout=0,\n    temperature=1,\n    gen_bias_init=0.)\nmodel = GMVAENetClassifier(\n    params,\n    optimizer_name=\"Adam\",\n    learning_rate=adam_lr,\n    loss=losses[\"GMVAELoss\"](),\n    use_cuda=(device.type != \"cpu\"))\nprint(\"- model:\", model)\n\nprint(\"- training...\")\ntrain_history, valid_history = model.training(\n    manager=manager,\n    nb_epochs=epochs,\n    checkpointdir=None,\n    fold_index=0,\n    with_validation=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Results\n-------\n\nLets now display the clustering results.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net = model.model\nnet.eval()\nwith torch.no_grad():\n    p_x_given_z, dists = net(\n        torch.from_numpy(ds_train.data.astype(np.float32)).to(device))\nq_y_given_x = dists[\"q_y_given_x\"]\ntrain_labels = np.argmax(q_y_given_x.logits.detach().cpu().numpy(), axis=1)\ntrain_acc = losses[\"GMVAELoss\"].cluster_acc(\n    q_y_given_x.logits, ds_train.labels, is_logits=True)\nprint(\"-- GMVAE ACC train\", train_acc)\nwith torch.no_grad():\n    p_x_given_z, dists = net(\n            torch.from_numpy(ds_val.data.astype(np.float32)).to(device))\nq_y_given_x = dists[\"q_y_given_x\"]\nval_labels = np.argmax(q_y_given_x.logits.detach().cpu().numpy(), axis=1)\nval_acc = losses[\"GMVAELoss\"].cluster_acc(\n    q_y_given_x.logits, ds_val.labels, is_logits=True)\nprint(\"-- GMVAE ACC val\", val_acc)\n\nfor cnt, (name, y, labels, acc) in enumerate((\n        (\"train\", y_train, train_labels, train_acc),\n        (\"val\", y_val, val_labels, val_acc))):\n    colors = labels.astype(float)\n    colors /= colors.max()\n    axs[2, cnt].scatter(y[:, 0], y[:, 1], c=colors, cmap=plt.cm.Spectral)\n    axs[2, cnt].xaxis.set_major_formatter(NullFormatter())\n    axs[2, cnt].yaxis.set_major_formatter(NullFormatter())\n    axs[2, cnt].set_title(\n        \"GMVAE clustering ({0}-ACC:{1:.3f})\".format(name, acc))\n    axs[2, cnt].axis(\"tight\")\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}