{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
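        "\n",
        "pynet dense functional brain networks extraction\n",
        "================================================\n",
        "\n",
        "Credit: A Grigis\n",
        "\n",
        "This example trains either a Spatiotemporal Attention Autoencoder (STAAE) or a\n",
        "deep variational autoencoder (DVAE) on the resting-state fMRI of 40 subjects\n",
        "from the ADHD dataset (`nilearn.datasets.fetch_adhd`), then regresses the voxel\n",
        "time series on the learned latent codes with a Lasso model to recover dense\n",
        "functional brain networks (FBNs).\n",
        "\n",
        "References:\n",
        "\n",
        "- Spatiotemporal Attention Autoencoder (STAAE) for ADHD Classification,\n",
        "  MICCAI 2020.\n",
        "- Deep Variational Autoencoder for Modeling Functional Brain Networks and ADHD\n",
        "  Identification, ISBI 2020.\n",
        "\n"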
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "if \"CI_MODE\" in os.environ:\n",
        "    sys.exit()\n",
        "\n",
        "\n",
        "# Import\n",
        "import logging\n",
        "import numpy as np\n",
        "import random\n",
        "import math\n",
        "import time\n",
        "import nibabel\n",
        "import scipy.ndimage as ndimage\n",
        "from nilearn import datasets\n",
        "from nilearn import masking\n",
        "from nilearn.image import iter_img\n",
        "from nilearn.image.resampling import resample_to_img\n",
        "from nilearn.input_data import MultiNiftiMasker\n",
        "from nilearn.plotting import plot_prob_atlas, plot_stat_map, show\n",
        "import pandas as pd\n",
        "from sklearn.linear_model import Lasso\n",
        "import pynet\n",
        "from pynet.datasets import DataManager\n",
        "from pynet.plotting import Board, update_board\n",
        "from pynet import NetParameters\n",
        "from pynet.interfaces import VAENetEncoder, STAAENetEncoder\n",
        "from pynet.utils import setup_logging\n",
        "from pynet.interfaces import DeepLearningInterface\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "\n",
        "# Global parameters\n",
        "MODEL = \"STAAE\"  # or \"DVAE\"\n",
        "DATADIR = \"/neurospin/nsap/research/stAAE/data\"\n",
        "WORKDIR = \"/neurospin/nsap/research/stAAE\"\n",
        "DATAFILE = os.path.join(WORKDIR, \"ADHD40.npy\")\n",
        "MASKFILE = os.path.join(WORKDIR, \"ADHD40_mask.nii.gz\")\n",
        "SEGFILE = \"./MNI152_T1_1mm_Brain_FAST_seg.nii.gz\"\n",
        "STRUCTFILE = \"./MNI152_T1_2mm_strucseg_periph.nii.gz\"\n",
        "N_SUBJECTS = 40  # keep in sync with the \"ADHD40\" cache file names above\n",
        "# Gray-matter label in SEGFILE (assumes FSL FAST's 1=CSF, 2=GM, 3=WM\n",
        "# convention -- TODO confirm).\n",
        "GM_LABEL = 2\n",
        "SEED = 1234\n",
        "BATCH_SIZE = 200\n",
        "NB_EPOCHS = 50\n",
        "LASSO_ALPHA = 0.01\n",
        "# Fall back to the CPU instead of failing when no GPU is available.\n",
        "USE_CUDA = torch.cuda.is_available()\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "setup_logging(level=\"info\")\n",
        "logger = logging.getLogger(\"pynet\")\n",
        "\n",
        "\n",
        "# Prepare data\n",
        "adhd_dataset = datasets.fetch_adhd(n_subjects=N_SUBJECTS, data_dir=DATADIR)\n",
        "func_filenames = adhd_dataset.func\n",
        "print(\"Functional nifti image: {0}...{1} ({2})\".format(\n",
        "    func_filenames[0], func_filenames[-1], len(func_filenames)))\n",
        "\n",
        "# Build an EPI-based gray-matter mask because we have no anatomical data:\n",
        "# keep voxels that are non-zero in the first EPI volume, labelled as gray\n",
        "# matter in SEGFILE and inside STRUCTFILE (both resampled on the EPI grid).\n",
        "if not os.path.isfile(MASKFILE):\n",
        "    target_img = nibabel.load(func_filenames[0])\n",
        "    # Slicing the data proxy reads only the first volume, not the whole run.\n",
        "    target_mask = np.asarray(target_img.dataobj[..., 0]) != 0\n",
        "    template = nibabel.load(SEGFILE)\n",
        "    struct = nibabel.load(STRUCTFILE)\n",
        "    resampled_template = resample_to_img(\n",
        "        template, target_img, interpolation=\"nearest\")\n",
        "    resampled_struct = resample_to_img(\n",
        "        struct, target_img, interpolation=\"nearest\")\n",
        "    gm_mask = resampled_template.get_fdata() == GM_LABEL\n",
        "    # Any non-zero STRUCTFILE voxel is kept (assumes a binary mask).\n",
        "    struct_mask = resampled_struct.get_fdata() != 0\n",
        "    # uint8 instead of the default int64, which NIfTI tools handle poorly.\n",
        "    mask = (gm_mask & struct_mask & target_mask).astype(np.uint8)\n",
        "    mask_img = nibabel.Nifti1Image(mask, target_img.affine)\n",
        "    nibabel.save(mask_img, MASKFILE)\n",
        "else:\n",
        "    mask_img = nibabel.load(MASKFILE)\n",
        "\n",
        "# Mask, detrend, standardize and smooth the EPI data, then concatenate all\n",
        "# subjects along the first (time) axis; the result is cached in DATAFILE.\n",
        "masker = MultiNiftiMasker(\n",
        "    mask_img=mask_img,\n",
        "    standardize=True,\n",
        "    detrend=True,\n",
        "    smoothing_fwhm=6.)\n",
        "masker.fit()\n",
        "if not os.path.isfile(DATAFILE):\n",
        "    timeseries = np.concatenate(masker.transform(func_filenames), axis=0)\n",
        "    print(timeseries.shape)\n",
        "    np.save(DATAFILE, timeseries)\n",
        "else:\n",
        "    timeseries = np.load(DATAFILE)\n",
        "\n",
        "# Data iterator: insert a singleton channel axis at position 1\n",
        "inputs = np.expand_dims(timeseries, axis=1)\n",
        "train_manager = DataManager.from_numpy(\n",
        "    train_inputs=inputs, batch_size=BATCH_SIZE, add_input=True)\n",
        "\n",
        "# Create model\n",
        "if MODEL == \"DVAE\":\n",
        "    losses = pynet.get_tools(tool_name=\"losses\")\n",
        "    loss_klass = losses[\"BetaHLoss\"]\n",
        "    params = NetParameters(\n",
        "        input_channels=1,\n",
        "        input_dim=inputs.shape[-1],\n",
        "        conv_flts=None,\n",
        "        dense_hidden_dims=[256, 128, 64],\n",
        "        latent_dim=32)\n",
        "    interface = VAENetEncoder(\n",
        "        params,\n",
        "        optimizer_name=\"Adam\",\n",
        "        learning_rate=0.00001,\n",
        "        loss=loss_klass(use_mse=True, beta=1.),\n",
        "        use_cuda=False)\n",
        "elif MODEL == \"STAAE\":\n",
        "    params = NetParameters(\n",
        "        input_dim=inputs.shape[-1])\n",
        "    interface = STAAENetEncoder(\n",
        "        params,\n",
        "        optimizer_name=\"Adam\",\n",
        "        learning_rate=0.001,\n",
        "        loss_name=\"MSELoss\",\n",
        "        use_cuda=USE_CUDA)\n",
        "else:\n",
        "    raise ValueError(\n",
        "        \"Unknown MODEL {0!r}: expected 'DVAE' or 'STAAE'\".format(MODEL))\n",
        "name = MODEL\n",
        "\n",
        "# Train model; the board is refreshed after every epoch\n",
        "interface.board = Board(\n",
        "    port=8097, host=\"http://localhost\", env=name.lower())\n",
        "interface.add_observer(\"after_epoch\", update_board)\n",
        "train_history, valid_history = interface.training(\n",
        "    manager=train_manager,\n",
        "    nb_epochs=NB_EPOCHS,\n",
        "    checkpointdir=os.path.join(WORKDIR, \"checkpoint_\" + name),\n",
        "    fold_index=0,\n",
        "    with_validation=False)\n",
        "\n",
        "# Create test data (the training samples, used to extract latent codes)\n",
        "test_manager = DataManager.from_numpy(\n",
        "    test_inputs=inputs, batch_size=BATCH_SIZE, add_input=True)\n",
        "\n",
        "\n",
        "def dummy_loss(*args, **kwargs):\n",
        "    \"\"\"Placeholder loss that always returns -1 (only latent codes are needed).\"\"\"\n",
        "    return -1\n",
        "\n",
        "\n",
        "# Get latent parameters: with `nodecoding` set, the model is expected to\n",
        "# output its latent code instead of a reconstruction (TODO confirm in pynet).\n",
        "interface.model.nodecoding = True\n",
        "interface.loss = dummy_loss\n",
        "y_pred, X, y_true, loss, values = interface.testing(\n",
        "    manager=test_manager,\n",
        "    with_logit=False,\n",
        "    predict=False)\n",
        "\n",
        "\n",
        "# Compute encoder weights: regress every voxel time series on the latent\n",
        "# codes, so clf.coef_ is (n_voxels, n_latents) and each column is a map.\n",
        "latents = np.squeeze(y_pred)\n",
        "print(latents.shape, timeseries.shape)\n",
        "clf = Lasso(alpha=LASSO_ALPHA)\n",
        "clf.fit(latents, timeseries)\n",
        "print(clf.coef_.shape)\n",
        "\n",
        "\n",
        "def thresholding(components):\n",
        "    \"\"\"Normalize and sign-flip spatial maps (no thresholding is applied).\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    components: array (n_components, n_voxels)\n",
        "        one spatial map per row; the input array is left unmodified.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    components: array (n_components, n_voxels)\n",
        "        a float copy where each non-zero row has unit L2 norm and is\n",
        "        negated when it has more negative than positive entries.\n",
        "    \"\"\"\n",
        "    components = np.array(components, dtype=float)\n",
        "    norms = np.sqrt(np.sum(components ** 2, axis=1))\n",
        "    norms[norms == 0] = 1\n",
        "    components /= norms[:, np.newaxis]\n",
        "\n",
        "    # Flip signs so that each map has at least as many positive as negative\n",
        "    # voxels (a voxel count, not an L1 comparison). Empirically this yields\n",
        "    # more positive-looking maps than making the maximum positive.\n",
        "    flip = np.sum(components > 0, axis=1) < np.sum(components < 0, axis=1)\n",
        "    components[flip] *= -1\n",
        "    return components\n",
        "\n",
        "\n",
        "def plot_net(components):\n",
        "    \"\"\"Map the components back to brain space with `masker` and plot them.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    components: array (n_components, n_voxels)\n",
        "        spatial maps in the voxel space of the fitted masker.\n",
        "    \"\"\"\n",
        "    components_img = masker.inverse_transform(components)\n",
        "\n",
        "    # Plot all components together, then one axial montage per component\n",
        "    plot_prob_atlas(components_img, title=\"All FBN components\")\n",
        "    for i, cur_img in enumerate(iter_img(components_img)):\n",
        "        plot_stat_map(cur_img, display_mode=\"z\", title=\"FBN %d\" % i,\n",
        "                      cut_coords=10, colorbar=False)\n",
        "\n",
        "    show()\n",
        "\n",
        "\n",
        "# Display the normalized, sign-flipped FBNs\n",
        "components = thresholding(clf.coef_.T)\n",
        "plot_net(components)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}