{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
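    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet: icosahedron UNet segmentation\n====================================\n\nCredit: A Grigis\n\nTrain a spherical UNet to segment a synthetic binary label map defined on\nthe vertices of an icosahedral mesh. The label map is obtained by\nthresholding a gaussian distance map centered on a random vertex, and the\nmulti-channel input features are random values generated from that map.\nThis is a toy demonstration: the same synthetic data is used for both\ntraining and testing.\n\n"
      ]
    },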
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "if \"CI_MODE\" in os.environ:\n",
        "    sys.exit()\n",
        "\n",
        "# Imports\n",
        "import collections\n",
        "import logging\n",
        "import pynet\n",
        "from pynet.datasets import DataManager\n",
        "from pynet.interfaces import SphericalUNetEncoder\n",
        "from pynet.utils import setup_logging\n",
        "from pynet.plotting import Board, update_board\n",
        "from pynet.models.spherical.sampling import icosahedron\n",
        "from pynet.plotting.surface import plot_trisurf\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from scipy.stats import norm\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "# Global Parameters\n",
        "OUTDIR = \"/tmp/ico_unet\"\n",
        "BATCH_SIZE = 5\n",
        "N_EPOCHS = 5\n",
        "N_CLASSES = 2\n",
        "N_SAMPLES = 40\n",
        "ICO_ORDER = 4\n",
        "SEED = 42\n",
        "# SAMPLES[klass][channel] = (loc, scale): at every vertex labelled\n",
        "# `klass`, the values of input channel `channel` are drawn from a normal\n",
        "# distribution with mean `loc` and standard deviation `scale`.\n",
        "SAMPLES = {\n",
        "    0: [(0, 1), (4, 2)],\n",
        "    1: [(2, 2), (2, 1)]}\n",
        "N_CHANNELS = len(SAMPLES[0])\n",
        "assert all(len(params) == N_CHANNELS for params in SAMPLES.values())\n",
        "# Seed every RNG used below so that Restart & Run All is reproducible.\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "setup_logging(level=\"debug\")\n",
        "\n",
        "\n",
        "def gaussian_sdist(vertices, triangles, n_maps, scales):\n",
        "    \"\"\" Generate gaussian distance features maps.\n",
        "\n",
        "    For each map, a random vertex is picked as center and every vertex\n",
        "    receives the value of a zero-mean normal pdf (standard deviation =\n",
        "    the map's scale) evaluated at its Euclidean distance to that center.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    vertices: array (n_vertices, n_dims)\n",
        "        the mesh vertices coordinates.\n",
        "    triangles: array (n_triangles, 3)\n",
        "        the mesh triangles (not used, kept for API compatibility).\n",
        "    n_maps: int\n",
        "        the number of maps to generate.\n",
        "    scales: list of float\n",
        "        the normal distribution standard deviation of each map.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    features: array (n_maps, n_vertices)\n",
        "        the generated feature maps.\n",
        "    \"\"\"\n",
        "    assert len(scales) == n_maps\n",
        "    locs = vertices[np.random.randint(0, len(vertices), n_maps)]\n",
        "    features = []\n",
        "    for loc, scale in zip(locs, scales):\n",
        "        dist = np.linalg.norm(vertices - loc, axis=1)\n",
        "        features.append(norm.pdf(dist, loc=0, scale=scale))\n",
        "    return np.asarray(features)\n",
        "\n",
        "\n",
        "# Load the data: a binary label map on the icosahedron, obtained by\n",
        "# thresholding a gaussian distance map centered on a random vertex.\n",
        "ico_vertices, ico_triangles = icosahedron(order=ICO_ORDER)\n",
        "print(ico_vertices.shape, ico_triangles.shape)\n",
        "prob = gaussian_sdist(ico_vertices, ico_triangles, n_maps=1, scales=[1])\n",
        "labels = (prob[0] > 0.25).astype(int)\n",
        "fig, ax = plt.subplots(1, 1, subplot_kw={\n",
        "    \"projection\": \"3d\", \"aspect\": \"auto\"}, figsize=(10, 10))\n",
        "tri_texture = np.asarray(\n",
        "    [np.round(np.mean(labels[tri])) for tri in ico_triangles])\n",
        "plot_trisurf(fig, ax, ico_vertices, ico_triangles, tri_texture)\n",
        "ax.set_title(\"Ground truth labels\")\n",
        "\n",
        "# Generate the input features. For each class, select the vertices of\n",
        "# that class (flatnonzero is always 1-D, even for a single vertex) and\n",
        "# fill each input channel with its own (loc, scale) gaussian; every\n",
        "# sample gets an independent draw.\n",
        "data = np.zeros((N_SAMPLES, N_CHANNELS, len(labels)), dtype=float)\n",
        "for klass, channel_params in SAMPLES.items():\n",
        "    k_indices = np.flatnonzero(labels == klass)\n",
        "    for channel, (loc, scale) in enumerate(channel_params):\n",
        "        data[:, channel, k_indices] = np.random.normal(\n",
        "            loc=loc, scale=scale, size=(N_SAMPLES, len(k_indices)))\n",
        "# Same vertex labels for every sample; np.tile keeps the integer dtype\n",
        "# used for class indices.\n",
        "sample_labels = np.tile(labels, (N_SAMPLES, 1))\n",
        "print(\"dataset: x {0} - y {1}\".format(data.shape, sample_labels.shape))\n",
        "\n",
        "\n",
        "# Create data manager (toy example: the same data is used for training\n",
        "# and testing)\n",
        "manager = DataManager.from_numpy(\n",
        "    train_inputs=data, train_labels=sample_labels, test_inputs=data,\n",
        "    test_labels=sample_labels, batch_size=BATCH_SIZE)\n",
        "\n",
        "\n",
        "# Create model\n",
        "net_params = pynet.NetParameters(\n",
        "    in_order=ICO_ORDER,\n",
        "    in_channels=N_CHANNELS,\n",
        "    out_channels=N_CLASSES,\n",
        "    depth=3,\n",
        "    start_filts=32,\n",
        "    conv_mode=\"1ring\",\n",
        "    up_mode=\"transpose\",\n",
        "    cachedir=os.path.join(OUTDIR, \"cache\"))\n",
        "model = SphericalUNetEncoder(\n",
        "    net_params,\n",
        "    optimizer_name=\"SGD\",\n",
        "    learning_rate=0.1,\n",
        "    momentum=0.99,\n",
        "    weight_decay=10**-4,\n",
        "    loss_name=\"CrossEntropyLoss\",\n",
        "    use_cuda=torch.cuda.is_available())\n",
        "# Training progress is pushed to the Board at localhost:8097 after each\n",
        "# epoch.\n",
        "model.board = Board(port=8097, host=\"http://localhost\", env=\"spherical_unet\")\n",
        "model.add_observer(\"after_epoch\", update_board)\n",
        "\n",
        "\n",
        "# Train model\n",
        "test_history, train_history = model.training(\n",
        "    manager=manager,\n",
        "    nb_epochs=N_EPOCHS,\n",
        "    checkpointdir=None,\n",
        "    fold_index=0,\n",
        "    scheduler=None,\n",
        "    with_validation=False)\n",
        "\n",
        "\n",
        "# Test model\n",
        "y_pred, X, y_true, loss, values = model.testing(\n",
        "    manager=manager,\n",
        "    with_logit=True,\n",
        "    predict=True)\n",
        "print(y_pred.shape, X.shape, y_true.shape)\n",
        "\n",
        "\n",
        "# Inspect results. y_pred is presumably (n_samples, n_vertices) class\n",
        "# indices (predict=True) - TODO confirm; each triangle texture is the\n",
        "# rounded mean over the samples and the triangle's vertices.\n",
        "fig, ax = plt.subplots(1, 1, subplot_kw={\n",
        "    \"projection\": \"3d\", \"aspect\": \"auto\"}, figsize=(10, 10))\n",
        "tri_texture = np.asarray(\n",
        "    [np.round(np.mean(y_pred[:, tri])) for tri in ico_triangles])\n",
        "plot_trisurf(fig, ax, ico_vertices, ico_triangles, tri_texture)\n",
        "ax.set_title(\"Predicted labels\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}