{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
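        "\nBeta VAE disentangling\n======================\n\nCredit: A Grigis\n\nTrain a VAE on the dSprites dataset with three different losses (\"betah\",\n\"betab\" and \"btcvae\"), then display the latent traversals and the\nreconstructions obtained with each saved training checkpoint.\n\n"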
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Imports\n",
        "import os\n",
        "import re\n",
        "import sys\n",
        "if \"CI_MODE\" in os.environ:\n",
        "    sys.exit()\n",
        "import glob\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.colors as mcolors\n",
        "import torch\n",
        "from torch.distributions import Normal, kl_divergence\n",
        "from pynet import NetParameters\n",
        "from pynet.datasets import DataManager\n",
        "from pynet.datasets.dsprites import DSprites\n",
        "from pynet.interfaces import VAENetEncoder\n",
        "from pynet.plotting import Board, update_board\n",
        "from pynet.losses import get_vae_loss\n",
        "from pynet.models.vae.utils import (\n",
        "    reconstruct_traverse, make_mosaic_img, add_labels)\n",
        "\n",
        "\n",
        "# Global parameters\n",
        "WDIR = \"/tmp/beta_vae_disentangling\"\n",
        "BATCH_SIZE = 64\n",
        "N_EPOCHS = 30\n",
        "ADAM_LR = 5e-4\n",
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "DISPLAY = False\n",
        "SEED = 42\n",
        "# Seed the global NumPy and PyTorch generators so that reruns draw the\n",
        "# same displayed samples (np.random.shuffle below).\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# Load the data\n",
        "dataset = DSprites(WDIR)\n",
        "manager = DataManager.from_dataset(\n",
        "    train_dataset=dataset, batch_size=BATCH_SIZE, sampler=\"random\")\n",
        "\n",
        "\n",
        "# Test different losses\n",
        "\n",
        "loss_params = {\n",
        "    \"betah\": {\"beta\": 4, \"steps_anneal\": 0, \"use_mse\": True},\n",
        "    \"betab\": {\"C_init\": 0.5, \"C_fin\": 25, \"gamma\": 100,\n",
        "              \"steps_anneal\": 100000, \"use_mse\": True},\n",
        "    \"btcvae\": {\"dataset_size\": len(dataset), \"alpha\": 1, \"beta\": 1, \"gamma\": 6,\n",
        "               \"is_mss\": True, \"steps_anneal\": 0, \"use_mse\": True}\n",
        "}\n",
        "\n",
        "\n",
        "def plot_losses(cache, filename):\n",
        "    \"\"\" Plot the per-dimension KL terms recorded by a VAE loss.\n",
        "\n",
        "    The left panel shows each KL term against the training iterations,\n",
        "    the right panel against the log likelihood.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    cache: dict\n",
        "        the loss cache; nothing is plotted unless it contains both the\n",
        "        'kl' and 'll' entries.\n",
        "    filename: str\n",
        "        the path where the figure is saved.\n",
        "    \"\"\"\n",
        "    if \"kl\" not in cache or \"ll\" not in cache:\n",
        "        return\n",
        "    ll = np.asarray(cache[\"ll\"]).squeeze()\n",
        "    kl = np.asarray(cache[\"kl\"]).squeeze()\n",
        "    fig, axs = plt.subplots(nrows=1, ncols=2)\n",
        "    colors = list(mcolors.TABLEAU_COLORS.keys())\n",
        "    for idx, dim_kl in enumerate(kl.T):\n",
        "        # Cycle through the palette: more latent dimensions than colors\n",
        "        # would otherwise raise an IndexError.\n",
        "        color = colors[idx % len(colors)]\n",
        "        label = \"dim{0}\".format(idx + 1)\n",
        "        axs[0].plot(dim_kl, color=color, label=label)\n",
        "        axs[1].plot(ll, dim_kl, color=color, label=label)\n",
        "    axs[0].set_xlabel(\"Training iterations\")\n",
        "    axs[0].set_ylabel(\"KL\")\n",
        "    axs[1].set_xlabel(\"Log Likelihood\")\n",
        "    axs[1].set_ylabel(\"KL\")\n",
        "    axs[1].legend(loc=\"upper left\")\n",
        "    fig.tight_layout()\n",
        "    fig.savefig(filename)\n",
        "\n",
        "\n",
        "def _checkpoint_epoch(path):\n",
        "    \"\"\" Return the epoch encoded in a '*epoch_<n>.pth' checkpoint name.\n",
        "\n",
        "    Names that do not follow this pattern get an infinite epoch so that\n",
        "    they are sorted after the numbered checkpoints.\n",
        "    \"\"\"\n",
        "    match = re.search(r\"epoch_(\\d+)\\.pth$\", os.path.basename(path))\n",
        "    if match is None:\n",
        "        return float(\"inf\")\n",
        "    return int(match.group(1))\n",
        "\n",
        "\n",
        "def plot_reconstructions(model, data, checkpointdir, filename=None):\n",
        "    \"\"\" Reconstruct the same inputs with every saved checkpoint.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    model: VAE model\n",
        "        the model whose weights are successively replaced by the state of\n",
        "        each checkpoint (the last checkpoint stays loaded on return).\n",
        "    data: torch.Tensor\n",
        "        the input images to reconstruct.\n",
        "    checkpointdir: str\n",
        "        the folder containing the '*.pth' checkpoints.\n",
        "    filename: str, default None\n",
        "        if specified, the path where the mosaic is saved.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    concatenated: image\n",
        "        the labelled mosaic made of the original images followed by one\n",
        "        reconstruction stage per checkpoint, ordered by training epoch.\n",
        "    \"\"\"\n",
        "    # Sort numerically on the epoch: a lexicographic sort would place\n",
        "    # 'epoch_5' after 'epoch_30' and scramble the stage order.\n",
        "    weights_files = sorted(\n",
        "        glob.glob(os.path.join(checkpointdir, \"*.pth\")),\n",
        "        key=lambda path: (_checkpoint_epoch(path), path))\n",
        "    original = np.expand_dims(data.cpu().numpy(), axis=0)\n",
        "    stages = [original]\n",
        "    labels = [\"orig\"]\n",
        "    for idx, path in enumerate(weights_files):\n",
        "        # Load the tensors on the device of the input data so that\n",
        "        # checkpoints saved on GPU can be reloaded on a CPU-only host.\n",
        "        checkpoint = torch.load(path, map_location=data.device)\n",
        "        model.load_state_dict(checkpoint[\"model\"])\n",
        "        reconstruction = model.reconstruct(data, sample=False)\n",
        "        stages.append(np.expand_dims(reconstruction, axis=0))\n",
        "        labels.append(\"rec stage {0}\".format(idx + 1))\n",
        "    concatenated = np.concatenate(stages, axis=0)\n",
        "    mosaic = make_mosaic_img(concatenated)\n",
        "    concatenated = Image.fromarray(mosaic)\n",
        "    concatenated = add_labels(concatenated, labels)\n",
        "    if filename is not None:\n",
        "        concatenated.save(filename)\n",
        "    return concatenated\n",
        "\n",
        "\n",
        "for loss_name in (\"betah\", \"betab\", \"btcvae\"):\n",
        "\n",
        "    # Train the model\n",
        "    checkpointdir = os.path.join(WDIR, \"checkpoints\", loss_name)\n",
        "    os.makedirs(checkpointdir, exist_ok=True)\n",
        "    weights_filename = os.path.join(\n",
        "        checkpointdir, \"model_0_epoch_{0}.pth\".format(N_EPOCHS))\n",
        "    params = NetParameters(\n",
        "        input_channels=1,\n",
        "        input_dim=DSprites.img_size,\n",
        "        conv_flts=[32, 32, 32, 32],\n",
        "        dense_hidden_dims=[256, 256],\n",
        "        latent_dim=10,\n",
        "        noise_out_logvar=-3,\n",
        "        noise_fixed=False,\n",
        "        act_func=None,\n",
        "        dropout=0,\n",
        "        sparse=False)\n",
        "    loss = get_vae_loss(loss_name=loss_name, **loss_params[loss_name])\n",
        "    if os.path.isfile(weights_filename):\n",
        "        # Reuse the final weights of a previous run.\n",
        "        vae = VAENetEncoder(\n",
        "            params,\n",
        "            optimizer_name=\"Adam\",\n",
        "            learning_rate=ADAM_LR,\n",
        "            loss=loss,\n",
        "            use_cuda=(DEVICE.type == \"cuda\"),\n",
        "            pretrained=weights_filename)\n",
        "    else:\n",
        "        # Train from scratch, updating the board after each epoch.\n",
        "        vae = VAENetEncoder(\n",
        "            params,\n",
        "            optimizer_name=\"Adam\",\n",
        "            learning_rate=ADAM_LR,\n",
        "            loss=loss,\n",
        "            use_cuda=(DEVICE.type == \"cuda\"))\n",
        "        vae.board = Board(\n",
        "            port=8097, host=\"http://localhost\", env=\"beta-vae\")\n",
        "        vae.add_observer(\"after_epoch\", update_board)\n",
        "        train_history, valid_history = vae.training(\n",
        "            manager=manager,\n",
        "            nb_epochs=(N_EPOCHS + 1),\n",
        "            checkpointdir=checkpointdir,\n",
        "            fold_index=0,\n",
        "            with_validation=False,\n",
        "            save_after_epochs=5)\n",
        "        plot_losses(vae.loss.cache,\n",
        "                    os.path.join(WDIR, \"loss_{0}.png\".format(loss_name)))\n",
        "    print(vae.model)\n",
        "\n",
        "    # Display results\n",
        "    index = np.arange(len(dataset))\n",
        "    np.random.shuffle(index)\n",
        "    # Keep 100 random indices before indexing: fancy-indexing with the\n",
        "    # full permutation would copy every image before keeping 100.\n",
        "    data = torch.unsqueeze(torch.from_numpy(\n",
        "        dataset.imgs[index[:100]].astype(np.float32)), dim=1).to(DEVICE)\n",
        "    vae.model.eval()\n",
        "    name = \"traverse_posteriror_{0}\".format(loss_name)\n",
        "    filename = os.path.join(WDIR, \"{0}.png\".format(name))\n",
        "    mosaic_traverse = reconstruct_traverse(\n",
        "        vae.model, data, n_per_latent=8, n_latents=None, is_posterior=True,\n",
        "        filename=filename)\n",
        "    filename = os.path.join(\n",
        "        WDIR, \"reconstruction_stages_{0}.png\".format(loss_name))\n",
        "    plot_reconstructions(vae.model, data[:8], checkpointdir, filename=filename)\n",
        "\n",
        "    if DISPLAY:\n",
        "        plt.figure()\n",
        "        plt.imshow(np.asarray(mosaic_traverse))\n",
        "        plt.title(name)\n",
        "        plt.axis(\"off\")\n",
        "\n",
        "if DISPLAY:\n",
        "    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}