{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
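        "\npynet: self-supervised clustering\n=================================\n\nCredit: A Grigis\n\nThis example trains a BrainNetCNN in a DeepCluster-like, self-supervised way: before each epoch the k-means pseudo labels are updated (`update_pseudo_labels`) and used as classification targets. The toy dataset contains two groups of noisy, nearly constant 90x90 matrices, so the recovered clusters can be compared against the true group labels at the end.\n\n"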
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "if \"CI_MODE\" in os.environ:\n",
        "    sys.exit()\n",
        "\n",
        "# Imports\n",
        "import collections\n",
        "import logging\n",
        "import pynet\n",
        "from pynet.metrics import SKMetrics\n",
        "from pynet.datasets import DataManager\n",
        "from pynet.interfaces import DeepLearningInterface\n",
        "from pynet.interfaces import DeepClusterClassifier\n",
        "from pynet.models import BrainNetCNN\n",
        "from pynet.utils import setup_logging\n",
        "from pynet.plotting import Board, update_board\n",
        "from pynet.models.deepcluster import update_pseudo_labels\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data.sampler import Sampler\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import pandas as pd\n",
        "from sklearn.cluster import MiniBatchKMeans, KMeans\n",
        "from sklearn.metrics import classification_report\n",
        "from sklearn.metrics import roc_curve, auc\n",
        "try:\n",
        "    # faiss is optional: only FKmeans (AVOID_EMPTY_CLUSTERS) needs it.\n",
        "    import faiss\n",
        "except ImportError:\n",
        "    faiss = None\n",
        "\n",
        "\n",
        "# Global Parameters\n",
        "OUTDIR = \"/tmp/graph_connectome\"\n",
        "BATCH_SIZE = 5\n",
        "N_EPOCHS = 20\n",
        "N_CLUSTERS = 2\n",
        "N_SAMPLES = 40\n",
        "AVOID_EMPTY_CLUSTERS = False\n",
        "UNIFORM_SAMPLING = True\n",
        "SEED = 42\n",
        "setup_logging(level=\"info\")\n",
        "\n",
        "# Seed the global NumPy and torch RNGs (data noise, k-means initialization,\n",
        "# index sampling, network initialization) so that a re-run is reproducible.\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "\n",
        "# Load the data: N_CLUSTERS groups of N_SAMPLES 90x90 matrices; the matrices\n",
        "# of group `idx` are filled with the value `idx` plus small uniform noise.\n",
        "data = []\n",
        "labels = []\n",
        "for idx in range(N_CLUSTERS):\n",
        "    x_train = np.ones((N_SAMPLES, 1, 90, 90)) * idx\n",
        "    x_train += (np.random.rand(*x_train.shape) - 0.5) * 0.01\n",
        "    y_train = np.asarray([idx] * N_SAMPLES)\n",
        "    print(\"sub data: x {0} - y {1}\".format(x_train.shape, y_train.shape))\n",
        "    data.append(x_train)\n",
        "    labels.append(y_train)\n",
        "\n",
        "    # Show the first three matrices of the group, titled with their label.\n",
        "    # A distinct loop variable keeps the group index `idx` from being shadowed.\n",
        "    fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n",
        "    for sample_idx, ax in enumerate(axes):\n",
        "        img = ax.imshow(np.squeeze(x_train[sample_idx]), interpolation=\"None\")\n",
        "        fig.colorbar(img, ax=ax)\n",
        "        ax.set_title(y_train[sample_idx])\n",
        "\n",
        "data = np.concatenate(data, axis=0)\n",
        "labels = np.concatenate(labels, axis=0)\n",
        "print(\"dataset: x {0} - y {1}\".format(data.shape, labels.shape))\n",
        "\n",
        "\n",
        "class UniformLabelSampler(Sampler):\n",
        "    \"\"\" Samples elements uniformly across pseudo labels.\n",
        "\n",
        "    Every non-empty pseudo label contributes the same number of indexes, so\n",
        "    small clusters are over-sampled and large clusters under-sampled.\n",
        "    \"\"\"\n",
        "    def __init__(self, data_array):\n",
        "        \"\"\" Init class.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        data_array: ArrayDataset\n",
        "            the train data array that contains the pseudo labels in its\n",
        "            `labels` attribute.\n",
        "        \"\"\"\n",
        "        self.n_samples = len(data_array)\n",
        "        self.data_array = data_array\n",
        "        # NOTE(review): the indexes are drawn once, here; if this sampler is\n",
        "        # reused after the pseudo labels change, call generate_indexes_epoch()\n",
        "        # again so that sampling follows the new labels.\n",
        "        self.indexes = self.generate_indexes_epoch()\n",
        "\n",
        "    def generate_indexes_epoch(self):\n",
        "        \"\"\" Generate sampling indexes.\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        indexes: list of int\n",
        "            `n_samples` shuffled indexes, (almost) equally distributed across\n",
        "            the non-empty pseudo labels; empty if there are no labels.\n",
        "        \"\"\"\n",
        "        clusters_to_images = self.get_clusters(self.data_array.labels)\n",
        "        n_non_empty_clusters = len(clusters_to_images)\n",
        "        if n_non_empty_clusters == 0:\n",
        "            return []\n",
        "        # n_clusters * (n_samples // n_clusters + 1) > n_samples, so the draws\n",
        "        # always contain enough indexes and truncation is sufficient.\n",
        "        size_per_pseudolabel = int(self.n_samples / n_non_empty_clusters) + 1\n",
        "        res = np.concatenate([\n",
        "            np.random.choice(\n",
        "                cluster_indexes, size_per_pseudolabel,\n",
        "                replace=(len(cluster_indexes) <= size_per_pseudolabel))\n",
        "            for cluster_indexes in clusters_to_images.values()])\n",
        "        np.random.shuffle(res)\n",
        "        return list(res.astype(\"int\")[:self.n_samples])\n",
        "\n",
        "    def get_clusters(self, labels):\n",
        "        \"\"\" Get indexes associated to each cluster.\n",
        "        \"\"\"\n",
        "        tally = collections.defaultdict(list)\n",
        "        for idx, item in enumerate(labels):\n",
        "            tally[item].append(idx)\n",
        "        return tally\n",
        "\n",
        "    def __iter__(self):\n",
        "        return iter(self.indexes)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.indexes)\n",
        "\n",
        "\n",
        "# Create data manager: the training labels are zero placeholders, presumably\n",
        "# replaced by the k-means pseudo labels during training (TODO confirm).\n",
        "if UNIFORM_SAMPLING:\n",
        "    sampler = UniformLabelSampler\n",
        "else:\n",
        "    sampler = \"random\"\n",
        "manager = DataManager.from_numpy(\n",
        "    train_inputs=data, train_labels=np.zeros(labels.shape),\n",
        "    batch_size=BATCH_SIZE, sampler=sampler)\n",
        "\n",
        "\n",
        "class FKmeans(object):\n",
        "    \"\"\" k-means clustering backed by faiss, exposing a fit/predict API.\n",
        "    \"\"\"\n",
        "    def __init__(self, n_clusters):\n",
        "        \"\"\" Init class.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        n_clusters: int\n",
        "            the number of clusters.\n",
        "        \"\"\"\n",
        "        if faiss is None:\n",
        "            raise ImportError(\n",
        "                \"faiss is required to use FKmeans (AVOID_EMPTY_CLUSTERS).\")\n",
        "        self.n_clusters = n_clusters\n",
        "\n",
        "    def fit(self, data):\n",
        "        \"\"\" Train the faiss k-means on a (n_data, d) array.\n",
        "        \"\"\"\n",
        "        # faiss only accepts C-contiguous float32 arrays.\n",
        "        data = np.ascontiguousarray(data, dtype=\"float32\")\n",
        "        n_data, d = data.shape\n",
        "        self.clus = faiss.Kmeans(d, self.n_clusters)\n",
        "        self.clus.seed = np.random.randint(1234)\n",
        "        self.clus.niter = 20\n",
        "        self.clus.max_points_per_centroid = 10000000\n",
        "        self.clus.train(data)\n",
        "\n",
        "    def predict(self, data):\n",
        "        \"\"\" Return the index of the nearest centroid for each row of data.\n",
        "        \"\"\"\n",
        "        data = np.ascontiguousarray(data, dtype=\"float32\")\n",
        "        _, nearest = self.clus.index.search(data, 1)\n",
        "        losses = self.clus.obj\n",
        "        print(\"k-means loss evolution: {0}\".format(losses))\n",
        "        return nearest[:, 0].astype(int)\n",
        "\n",
        "\n",
        "def my_loss(x, y):\n",
        "    \"\"\" Debugging loss: prints the outputs and targets, then returns the\n",
        "    cross-entropy (swap it in with `loss=my_loss` below).\n",
        "    \"\"\"\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    print(\"  x: {0} - {1}\".format(x.shape, x.dtype))\n",
        "    print(torch.argmax(x, dim=1))\n",
        "    print(\"  y: {0} - {1}\".format(y.shape, y.dtype))\n",
        "    print(y)\n",
        "    return criterion(x, y)\n",
        "\n",
        "\n",
        "# Create model\n",
        "train_loader = manager.get_dataloader(train=True, fold_index=0).train\n",
        "if AVOID_EMPTY_CLUSTERS:\n",
        "    kmeans = FKmeans(n_clusters=N_CLUSTERS)\n",
        "else:\n",
        "    # random_state=None: initialization is drawn from the global NumPy RNG\n",
        "    # seeded above.\n",
        "    kmeans = KMeans(\n",
        "        n_clusters=N_CLUSTERS,\n",
        "        random_state=None,\n",
        "        max_iter=20)\n",
        "net = BrainNetCNN(\n",
        "    input_shape=(90, 90),\n",
        "    in_channels=1,\n",
        "    num_classes=N_CLUSTERS,\n",
        "    nb_e2e=32,\n",
        "    nb_e2n=64,\n",
        "    nb_n2g=30,\n",
        "    dropout=0,\n",
        "    leaky_alpha=0.1,\n",
        "    twice_e2e=False,\n",
        "    dense_sml=False)\n",
        "net_params = pynet.NetParameters(\n",
        "    network=net,\n",
        "    clustering=kmeans,\n",
        "    data_loader=train_loader,\n",
        "    n_batchs=10,\n",
        "    pca_dim=6,\n",
        "    assignment_logfile=None,\n",
        "    use_cuda=False)\n",
        "model = DeepClusterClassifier(\n",
        "    net_params,\n",
        "    optimizer_name=\"SGD\",\n",
        "    learning_rate=0.001,\n",
        "    momentum=0.9,\n",
        "    weight_decay=10**-5,\n",
        "    # loss=my_loss,\n",
        "    loss_name=\"CrossEntropyLoss\")\n",
        "# Live monitoring board (needs a server listening on localhost:8097).\n",
        "model.board = Board(port=8097, host=\"http://localhost\", env=\"deepcluster\")\n",
        "model.add_observer(\"before_epoch\", update_pseudo_labels)\n",
        "model.add_observer(\"after_epoch\", update_board)\n",
        "\n",
        "\n",
        "# Train model\n",
        "test_history, train_history = model.training(\n",
        "    manager=manager,\n",
        "    nb_epochs=N_EPOCHS,\n",
        "    checkpointdir=None,\n",
        "    fold_index=0,\n",
        "    scheduler=None,\n",
        "    with_validation=False)\n",
        "\n",
        "\n",
        "# Test model\n",
        "test_manager = DataManager.from_numpy(\n",
        "    test_inputs=data, test_labels=labels, batch_size=BATCH_SIZE)\n",
        "test_model = DeepLearningInterface(\n",
        "    model=model.model.network,\n",
        "    optimizer_name=\"SGD\",\n",
        "    learning_rate=0.01,\n",
        "    momentum=0.9,\n",
        "    weight_decay=10**-5,\n",
        "    loss_name=\"CrossEntropyLoss\")\n",
        "y_prob, X, y_true, loss, values = test_model.testing(\n",
        "    manager=test_manager,\n",
        "    with_logit=True,\n",
        "    # logit_function=\"sigmoid\",\n",
        "    predict=False)\n",
        "print(y_prob.shape, X.shape, y_true.shape)\n",
        "\n",
        "\n",
        "# Inspect results\n",
        "# NOTE: cluster indexes are arbitrary, so a perfect clustering may show up as\n",
        "# an anti-diagonal confusion matrix and an AUC close to 0 instead of 1.\n",
        "y_true = y_true.squeeze()\n",
        "y_pred = np.argmax(y_prob, axis=1)\n",
        "result = pd.DataFrame.from_dict(collections.OrderedDict(\n",
        "    [(\"pred\", y_pred.astype(int)), (\"truth\", y_true)] +\n",
        "    [(\"prob_{0}\".format(k), y_prob[:, k]) for k in range(y_prob.shape[1])]))\n",
        "print(result)\n",
        "fig, ax = plt.subplots()\n",
        "cmap = plt.get_cmap(\"Blues\")\n",
        "cm = SKMetrics(\"confusion_matrix\", with_logit=False)(y_pred, y_true)\n",
        "sns.heatmap(cm, cmap=cmap, annot=True, fmt=\"g\", ax=ax)\n",
        "ax.set_xlabel(\"predicted values\")\n",
        "ax.set_ylabel(\"actual values\")\n",
        "# y_pred already holds class indexes: report on them directly (thresholding\n",
        "# class indexes is meaningless and mixes booleans with labels).\n",
        "print(classification_report(y_true, y_pred))\n",
        "if y_prob.shape[1] == 2:\n",
        "    # Score samples with the class-1 minus class-0 output: a continuous score\n",
        "    # gives a full ROC curve (hard argmax labels give a single point), and the\n",
        "    # margin is monotonic in the softmax class-1 probability whether the\n",
        "    # network outputs are logits or softmax probabilities.\n",
        "    scores = y_prob[:, 1] - y_prob[:, 0]\n",
        "    fpr, tpr, _ = roc_curve(y_true, scores)\n",
        "    roc_auc = auc(fpr, tpr)\n",
        "    fig, ax = plt.subplots()\n",
        "    ax.plot(fpr, tpr, color=\"darkorange\", lw=2,\n",
        "            label=\"ROC curve (area = %0.2f)\" % roc_auc)\n",
        "    ax.plot([0, 1], [0, 1], color=\"navy\", lw=2, linestyle=\"--\")\n",
        "    ax.set_xlim([0.0, 1.0])\n",
        "    ax.set_ylim([0.0, 1.05])\n",
        "    ax.set_xlabel(\"False Positive Rate\")\n",
        "    ax.set_ylabel(\"True Positive Rate\")\n",
        "    ax.set_title(\"Receiver operating characteristic\")\n",
        "    ax.legend(loc=\"lower right\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}