{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
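        "\npynet: predicting autism\n========================\n\nCredit: A Grigis\n\nThis tutorial is based on Kawahara et al., *BrainNetCNN: Convolutional\nneural networks for brain networks; towards predicting neurodevelopment*,\nNeuroImage 2017: <https://www2.cs.sfu.ca/~hamarneh/ecopy/neuroimage2017.pdf>.\n\nThe paper assesses BrainNetCNN's ability to learn and discriminate between\ndiffering network topologies using sets of synthetically generated networks.\nIt first examines the performance of BrainNetCNN on data with increasing\nlevels of noise and then compares BrainNetCNN to a fully-connected neural\nnetwork with the same number of model parameters.\n\nThis notebook reproduces the synthetic experiment: realistic connectomes are\nperturbed by two injury signatures, and a BrainNetCNN is trained to regress\nthe weight applied to each signature.\n\n"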
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "if \"CI_MODE\" in os.environ:\n",
        "    sys.exit()\n",
        "import logging\n",
        "import shutil\n",
        "import pynet\n",
        "from pynet.datasets import DataManager, get_fetchers\n",
        "from pynet.utils import setup_logging\n",
        "from pynet.metrics import SKMetrics\n",
        "from pynet.plotting import Board, update_board\n",
        "from mne.viz import circular_layout, plot_connectivity_circle\n",
        "import collections\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.optim import lr_scheduler\n",
        "import scipy\n",
        "# 'scipy.stats.stats' is a deprecated private module: import from the\n",
        "# public 'scipy.stats' namespace instead.\n",
        "from scipy.stats import pearsonr\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "setup_logging(level=\"info\")\n",
        "logger = logging.getLogger(\"pynet\")\n",
        "\n",
        "# Fix the numpy and torch random seeds so that the stochastic parts of\n",
        "# this notebook (e.g. weight initialization and dropout) are reproducible.\n",
        "SEED = 42\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# Load the data\n",
        "outdir = \"/tmp/graph_connectome\"\n",
        "(injury, x_train, y_train, x_test, y_test, x_valid,\n",
        " y_valid) = get_fetchers()[\"fetch_connectome\"](outdir)\n",
        "labels = [str(idx) for idx in range(1, x_train.shape[-1] + 1)]\n",
        "for name, (x, y) in ((\"train\", (x_train, y_train)),\n",
        "                     (\"test\", (x_test, y_test)),\n",
        "                     (\"validation\", (x_valid, y_valid))):\n",
        "    print(\"{0}: x {1} - y {2}\".format(name, x.shape, y.shape))\n",
        "\n",
        "# View the realistic base connectome and the injury signatures.\n",
        "panels = (\n",
        "    (injury.X_mn, \"base connectome\"),\n",
        "    (injury.sigs[0], \"signature 1\"),\n",
        "    (injury.sigs[1], \"signature 2\"))\n",
        "plt.figure(figsize=(16, 4))\n",
        "for idx, (matrix, title) in enumerate(panels):\n",
        "    plt.subplot(1, 3, idx + 1)\n",
        "    plt.imshow(matrix, interpolation=\"None\")\n",
        "    plt.colorbar()\n",
        "    plt.title(title)\n",
        "\n",
        "# Show example noisy training data that have the signatures applied.\n",
        "# The differences are too subtle to see by eye, but the cross-shaped\n",
        "# row/column patterns of the signatures above have been added to the\n",
        "# matrices below, weighted by the y values shown in each title.\n",
        "plt.figure(figsize=(16, 4))\n",
        "for idx in range(3):\n",
        "    plt.subplot(1, 3, idx + 1)\n",
        "    plt.imshow(np.squeeze(x_train[idx]), interpolation=\"None\")\n",
        "    plt.colorbar()\n",
        "    plt.title(y_train[idx])\n",
        "\n",
        "manager = DataManager.from_numpy(\n",
        "    train_inputs=x_train, train_labels=y_train,\n",
        "    validation_inputs=x_valid, validation_labels=y_valid,\n",
        "    test_inputs=x_test, test_labels=y_test,\n",
        "    batch_size=128, continuous_labels=True)\n",
        "interfaces = pynet.get_interfaces()[\"graph\"]\n",
        "# Derive the sizes from the data instead of hard-coding them: the inputs\n",
        "# are square n_nodes x n_nodes connectomes and the model regresses one\n",
        "# weight per injury signature (one column of y).\n",
        "net_params = pynet.NetParameters(\n",
        "    input_shape=tuple(x_train.shape[-2:]),\n",
        "    in_channels=1,\n",
        "    num_classes=y_train.shape[-1],\n",
        "    nb_e2e=32,\n",
        "    nb_e2n=64,\n",
        "    nb_n2g=30,\n",
        "    dropout=0.5,\n",
        "    leaky_alpha=0.1,\n",
        "    twice_e2e=False,\n",
        "    dense_sml=True)\n",
        "model = interfaces[\"BrainNetCNNGraph\"](\n",
        "    net_params,\n",
        "    optimizer_name=\"Adam\",\n",
        "    learning_rate=0.01,\n",
        "    weight_decay=0.0005,\n",
        "    loss_name=\"MSELoss\")\n",
        "model.board = Board(port=8097, host=\"http://localhost\", env=\"main\")\n",
        "model.add_observer(\"after_epoch\", update_board)\n",
        "scheduler = lr_scheduler.ReduceLROnPlateau(\n",
        "    optimizer=model.optimizer,\n",
        "    mode=\"min\",\n",
        "    factor=0.1,\n",
        "    patience=5,\n",
        "    verbose=True,\n",
        "    eps=1e-8)\n",
        "test_history, train_history = model.training(\n",
        "    manager=manager,\n",
        "    nb_epochs=15,\n",
        "    checkpointdir=None,\n",
        "    fold_index=0,\n",
        "    scheduler=scheduler,\n",
        "    with_validation=True)\n",
        "y_pred, X, y_true, loss, values = model.testing(\n",
        "    manager=manager,\n",
        "    with_logit=True,\n",
        "    predict=False)\n",
        "y_pred_0, y_pred_1 = y_pred.T\n",
        "y_true_0, y_true_1 = y_true.T\n",
        "result = pd.DataFrame.from_dict(collections.OrderedDict([\n",
        "    (\"pred_0\", y_pred_0),\n",
        "    (\"truth_0\", y_true_0),\n",
        "    (\"pred_1\", y_pred_1),\n",
        "    (\"truth_1\", y_true_1)]))\n",
        "print(result)\n",
        "\n",
        "\n",
        "def regression_metrics(pred_labels, true_labels):\n",
        "    \"\"\" Regression metrics as defined in the tutorial.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    pred_labels: array\n",
        "        the predicted values of shape (n_samples, ) or\n",
        "        (n_samples, n_labels); singleton axes are squeezed.\n",
        "    true_labels: array\n",
        "        the ground truth values, same shape as 'pred_labels'.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    met: dict\n",
        "        the mean absolute difference ('mad') and its standard deviation\n",
        "        ('std_mad'), plus the Pearson correlation ('corr_<i>') and its\n",
        "        p-value ('p_<i>') for each label i.\n",
        "    \"\"\"\n",
        "    # Squeeze once so that the rank test and the column indexing below\n",
        "    # operate on the same array.\n",
        "    pred_labels = np.squeeze(np.asarray(pred_labels))\n",
        "    true_labels = np.squeeze(np.asarray(true_labels))\n",
        "    abs_diff = np.abs(pred_labels - true_labels)\n",
        "    met = {\"mad\": np.mean(abs_diff), \"std_mad\": np.std(abs_diff)}\n",
        "    # There are multiple labels: one correlation per column.\n",
        "    if pred_labels.ndim > 1:\n",
        "        for idx in range(pred_labels.shape[1]):\n",
        "            r, p = pearsonr(pred_labels[:, idx], true_labels[:, idx])\n",
        "            met[\"corr_\" + str(idx)] = r\n",
        "            met[\"p_\" + str(idx)] = p\n",
        "    # Only 1 label.\n",
        "    else:\n",
        "        r, p = pearsonr(pred_labels, true_labels)\n",
        "        met[\"corr_0\"] = r\n",
        "        met[\"p_0\"] = p\n",
        "    return met\n",
        "\n",
        "\n",
        "print(\"E2E prediction results:\")\n",
        "test_metrics_0 = regression_metrics(y_pred_0, y_true_0)\n",
        "print(\"class 0: {0}\".format(test_metrics_0))\n",
        "test_metrics_1 = regression_metrics(y_pred_1, y_true_1)\n",
        "print(\"class 1: {0}\".format(test_metrics_1))\n",
        "\n",
        "# Saliency map: back-propagate the sum of all the output scores to the\n",
        "# input connectomes, take the absolute gradient, keep its maximum over\n",
        "# dim 1 (presumably the channel axis, since in_channels=1) and average it\n",
        "# over the test subjects.\n",
        "model.model.eval()\n",
        "x_test_tensor = torch.from_numpy(x_test)\n",
        "x_test_tensor.requires_grad_()\n",
        "scores = model.model(x_test_tensor)\n",
        "scores.backward(torch.ones_like(scores))\n",
        "saliency, _ = torch.max(x_test_tensor.grad.detach().abs(), dim=1)\n",
        "saliency = np.mean(saliency.numpy(), axis=0)\n",
        "\n",
        "# Circular layout: first half of the nodes on one side, second half\n",
        "# (reversed) on the other side.\n",
        "hemi_size = len(labels) // 2\n",
        "node_order = labels[:hemi_size]\n",
        "node_order.extend(labels[hemi_size:][::-1])\n",
        "node_angles = circular_layout(labels, node_order, start_pos=90,\n",
        "                              group_boundaries=[0, hemi_size])\n",
        "plot_connectivity_circle(saliency, labels, n_lines=300,\n",
        "                         node_angles=node_angles,\n",
        "                         title=\"Partial derivatives mapped on a circle plot\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}