Helper Module for Deep Learning.
Source code for pynet.datasets.registration
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019 - 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Module that provides functions to prepare registration dataset.
"""
# Imports
import os
import json
import h5py
import glob
import urllib
import shutil
import requests
import logging
import numpy as np
import scipy
import zipfile
import skimage
from collections import namedtuple
import pandas as pd
from pynet.datasets import Fetchers
# Global parameters
# Record returned by 'fetch_registration' (see the fields below).
Item = namedtuple("Item", ["input_path", "output_path", "metadata_path",
"labels"])
# Endpoint queried by 'download_file_from_google_drive'.
URL = "https://docs.google.com/uc?export=download"
# Value sent as the 'id' request parameter together with URL.
ID = "1rJtP9M1N3lSjNzJ5kIzRrrwPe1bWCfXB"
# Location of the atlas file downloaded by 'fetch_registration'.
ATLAS = ("https://github.com/voxelmorph/voxelmorph/raw/master/data/"
"atlas_norm.npz")
# Module logger, registered under the 'pynet' name.
logger = logging.getLogger("pynet")
def download_file_from_google_drive(destination):
    """ Download the file identified by the module-level 'ID' from 'URL'.

    A first request is sent with the file identifier. If the response
    carries a confirmation token (see 'get_confirm_token'), a second
    request is sent with this token as the 'confirm' parameter. The body of
    the last response is then written to disk.

    Parameters
    ----------
    destination: str
        the path of the file where the downloaded content is saved.
    """
    # The session is used as a context manager so that the underlying
    # connections are released when the download is over.
    with requests.Session() as session:
        response = session.get(URL, params={"id": ID}, stream=True)
        token = get_confirm_token(response)
        if token:
            # The first streamed response is not consumed: release it before
            # issuing the confirmed request.
            response.close()
            params = {"id": ID, "confirm": token}
            response = session.get(URL, params=params, stream=True)
        try:
            save_response_content(response, destination)
        finally:
            response.close()
def get_confirm_token(response):
    """ Search the response cookies for a download confirmation token.

    Parameters
    ----------
    response: object
        an object exposing a 'cookies' mapping (name -> value).

    Returns
    -------
    token: str or None
        the value of the first cookie whose name starts with
        'download_warning', None if there is no such cookie.
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning'))
    return next(warning_values, None)
def save_response_content(response, destination):
    """ Write the content of a response in a file, chunk by chunk.

    Parameters
    ----------
    response: object
        an object exposing an 'iter_content(chunk_size)' method that yields
        the content as binary chunks.
    destination: str
        the path of the file to write (opened in binary mode, overwritten).
    """
    chunk_size = 32768
    with open(destination, "wb") as out_file:
        # Empty chunks (keep-alive) are discarded.
        for block in filter(None, response.iter_content(chunk_size)):
            out_file.write(block)
def wl_normalization(img, w=290, l=120):
    """ Rescale the intensities of an image from a window to [0, 255].

    The input window is centered on 'l' and has a width of 'w', i.e. it
    spans [l - w / 2, l + w / 2]. The rescaling itself is delegated to
    'skimage.exposure.rescale_intensity'.

    Parameters
    ----------
    img: array
        the image to normalize.
    w: float, default 290
        the width of the input intensity window.
    l: float, default 120
        the center of the input intensity window.

    Returns
    -------
    norm: array
        the rescaled image casted to uint8.
    """
    half_width = w / 2
    window = (l - half_width, l + half_width)
    rescaled = skimage.exposure.rescale_intensity(
        img, in_range=window, out_range=(0, 255))
    return rescaled.astype(np.uint8)
def crop(arr, bound_l, bound_r, target_shape, order=1):
    """ Extract a box along the three first axes and resample it.

    Parameters
    ----------
    arr: array
        the array to crop.
    bound_l: sequence of int
        the (included) lower bounds of the box, one per axis.
    bound_r: sequence of int
        the (excluded) upper bounds of the box, one per axis.
    target_shape: sequence of int
        the requested shape; the zoom factors are the ratios between this
        shape and the shape of the extracted box.
    order: int, default 1
        the interpolation order forwarded to 'scipy.ndimage.zoom'.

    Returns
    -------
    resampled: array
        the extracted box resampled by 'scipy.ndimage.zoom'.
    """
    box = tuple(slice(bound_l[axis], bound_r[axis]) for axis in range(3))
    patch = arr[box]
    zoom_factors = np.array(target_shape) / np.array(patch.shape)
    return scipy.ndimage.zoom(patch, zoom_factors, order=order)
def crop_mask(volume, segmentation, target_shape=(128, 128, 128)):
    """ Crop a volume around the non-zero area of a segmentation.

    The bounding box of the non-zero elements of the segmentation is
    enlarged on each side by 10% of its size (at least 5 elements), clipped
    to the segmentation extent, and used to crop the volume. The cropped
    volume is resampled to the target shape with 'crop' and normalized with
    'wl_normalization'.

    Parameters
    ----------
    volume: array
        the volume to crop.
    segmentation: array
        the array whose non-zero elements define the bounding box.
    target_shape: 3-uplet, default (128, 128, 128)
        the shape of the returned volume.

    Returns
    -------
    cropped: array
        the cropped, resampled and normalized volume.
    """
    # One row per non-zero element, one column per axis.
    coords = np.argwhere(segmentation)
    upper = coords.max(axis=0)
    lower = coords.min(axis=0)
    extent = upper - lower + 1
    margin = np.maximum((extent * 0.1).astype(np.int32), 5)
    lower = np.maximum(lower - margin, 0)
    # The upper bounds are exclusive, hence the '+ 1'.
    upper = np.minimum(upper + margin + 1, segmentation.shape)
    resized = crop(volume, lower, upper, target_shape)
    return wl_normalization(resized)
@Fetchers.register
def fetch_registration(datasetdir):
    """ Fetch/prepare the registration dataset for pynet.

    The archive and the atlas are downloaded in the dataset folder when
    missing and the archive is extracted. The 'volume' of every key found in
    the extracted '.h5' files is loaded, and each volume is paired with a
    reference volume (the first loaded one) along a new second axis. The
    resulting array is saved in 'pynet_registration_inputs.npy' and the
    description of the volumes in 'pynet_registration.tsv'. The processing
    is skipped when both files already exist.

    Parameters
    ----------
    datasetdir: str
        the dataset destination folder.

    Returns
    -------
    item: namedtuple
        a named tuple containing 'input_path', 'output_path',
        'metadata_path' and 'labels' ('output_path' and 'labels' are None).

    Raises
    ------
    ValueError
        if the same key is found more than once in the '.h5' files.
    requests.HTTPError
        if the atlas download returns an HTTP error status.
    """
    logger.info("Loading registration dataset...")
    if not os.path.isdir(datasetdir):
        # 'makedirs' also creates the missing parent folders.
        os.makedirs(datasetdir)
    desc_path = os.path.join(datasetdir, "pynet_registration.tsv")
    input_path = os.path.join(datasetdir, "pynet_registration_inputs.npy")
    # Both generated files are required, otherwise the returned item would
    # point to a missing input array.
    if not (os.path.isfile(desc_path) and os.path.isfile(input_path)):
        logger.debug("Processing {0}...".format(URL))
        zfile = os.path.join(datasetdir, "brain_train.zip")
        if not os.path.isfile(zfile):
            download_file_from_google_drive(zfile)
        else:
            logger.info("ZIP already downloaded!")

        afile = os.path.join(datasetdir, "atlas_norm.npz")
        if not os.path.isfile(afile):
            response = requests.get(ATLAS, stream=True)
            try:
                # Fail before creating the file: an error page saved as the
                # atlas would be reused by the next calls.
                response.raise_for_status()
                with open(afile, "wb") as out_file:
                    shutil.copyfileobj(response.raw, out_file)
            finally:
                response.close()
        else:
            logger.info("ATLAS already downloaded!")

        downloadir = os.path.join(datasetdir, "datasets")
        if not os.path.isdir(downloadir):
            with zipfile.ZipFile(zfile, "r") as zip_ref:
                zip_ref.extractall(downloadir)
        else:
            logger.info("Archive already opened!")

        # TODO fix that: the atlas is loaded but not used yet. The former
        # code scaled it by 7000, binarized it as a mask and tried to export
        # an 'atlas_norm.nii.gz' file, but the 'crop_mask(atlas, mask)' step
        # producing the exported array was disabled, so the export always
        # failed silently. The first loaded volume is used as reference
        # below instead.
        atlas = np.load(afile)["vol"]
        logger.debug("Atlas {0}...".format(atlas.shape))

        # NOTE: 'glob' does not guarantee any ordering, so the volume used
        # as reference below depends on the order returned here.
        files = glob.glob(os.path.join(downloadir, "*.h5"))
        all_arrs = []
        metadata = {
            name: [] for name in ("subjects", "centers", "studies", "keys")}
        for h5file in files:
            logger.debug("Processing {0}...".format(h5file))
            study = os.path.basename(h5file).replace(".h5", "")
            # The file is closed as soon as its volumes are copied in memory.
            with h5py.File(h5file, "r") as open_file:
                for key in open_file.keys():
                    if key in metadata["keys"]:
                        raise ValueError(
                            "Key '{0}' appears multiple time.".format(key))
                    try:
                        center, sid = key.split("-")
                    except ValueError:
                        # The key is not of the form '<center>-<subject>'.
                        center = "na"
                        sid = key
                    logger.debug("Processing key {0} ({1}-{2})...".format(
                        key, sid, center))
                    volume = open_file[key]["volume"]
                    metadata["subjects"].append(sid)
                    metadata["centers"].append(center)
                    metadata["studies"].append(study)
                    metadata["keys"].append(key)
                    all_arrs.append(np.array(volume))
        data = np.asarray(all_arrs)

        # Best-effort export of the stacked volumes: nibabel is optional and
        # a failure here must not prevent the dataset creation.
        try:
            import nibabel
        except ImportError:
            logger.debug("nibabel is not installed: NIfTI export skipped.")
        else:
            try:
                im = nibabel.Nifti1Image(
                    np.transpose(data, (1, 2, 3, 0)), np.eye(4))
                nibabel.save(im, os.path.join(datasetdir, "data.nii.gz"))
            except Exception:
                logger.warning(
                    "Unable to save 'data.nii.gz'.", exc_info=True)

        # The first volume is repeated to get one reference per volume, then
        # volumes and references are stacked along a new second axis.
        atlas_norm = data[0]
        atlas_norm = np.expand_dims(atlas_norm, axis=0)
        atlas_norm = np.repeat(atlas_norm, len(data), axis=0)
        data = np.expand_dims(data, axis=1)
        atlas_norm = np.expand_dims(atlas_norm, axis=1)
        logger.debug("Data: {0}".format(data.shape))
        logger.debug("Atlas: {0}".format(atlas_norm.shape))
        data = np.concatenate((data, atlas_norm), axis=1)
        logger.debug("Input: {0}-{1}".format(data.shape, data.dtype))

        # The description is written last, after the input array.
        np.save(input_path, data)
        df = pd.DataFrame.from_dict(metadata)
        df.to_csv(desc_path, sep="\t", index=False)
    return Item(input_path=input_path, output_path=None,
                metadata_path=desc_path, labels=None)
Follow us
© 2019, pynet developers.
Inspired by AZMIND template.
Inspired by AZMIND template.