Menu

Helper Module for Deep Learning.

Source code for pynet.plotting.image

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Common functions to display images.
"""

# Import
import logging
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision


# Global parameters
# Module logger. It is fetched by the explicit name "pynet" (not __name__), so
# any level/handler configured on the "pynet" logger applies to the messages
# emitted by the plotting helpers below.
logger = logging.getLogger("pynet")


def plot_data(data, slice_axis=2, nb_samples=5, channel=0, labels=None,
              random=True, rgb=False, title=None):
    """ Plot an image associated data.

    Currently support 2D or 3D dataset of the form (samples, channels, dim).
    3D volumes are first cut into 2D slices along 'slice_axis'; only the
    slices whose 'channel' contains at least one strictly positive value are
    eligible for display.

    Parameters
    ----------
    data: array (samples, channels, X, Y) or (samples, channels, X, Y, Z)
        the data to be displayed.
    slice_axis: int, default 2
        the slice axis (0, 1 or 2, relative to the spatial dims) for 3D data.
    nb_samples: int, default 5
        the number of samples to be displayed; capped by the number of
        eligible slices.
    channel: int, default 0
        will select slices with data using the provided channel.
    labels: list of str, default None
        the data labels to be displayed, indexed by slice index (after the
        3D reorganization); if None, 'Image <index>' is displayed.
    random: bool, default True
        select randomly (without replacement) 'nb_samples' data, otherwise
        the 'nb_samples' firsts.
    rgb: bool, default False
        if set expect three RGB channels, displayed as one color image.
    title: str, default None
        the figure title.

    Raises
    ------
    ValueError
        if the data dimension is not 4 or 5, if 'rgb' is set without exactly
        3 channels, or if 'slice_axis' is not 0, 1 or 2 for 3D data.
    """
    # Check input parameters
    if data.ndim not in (4, 5):
        raise ValueError("Unsupported data dimension.")
    nb_channels = data.shape[1]
    if rgb:
        if nb_channels != 3:
            raise ValueError("With RGB mode activated expect exactly 3 "
                             "channels.")
        # The three channels are merged into a single color image per sample.
        nb_channels = 1

    # Reorganize 3D data: each (channels, X, Y, Z) volume is transposed so
    # that 'slice_axis' comes first, then all slices are stacked as samples
    # of shape (channels, dim_a, dim_b).
    if data.ndim == 5:
        axes = [0, 1, 2]
        if slice_axis not in axes:
            raise ValueError(
                "Unsupported slice axis: {0}.".format(slice_axis))
        axes.remove(slice_axis)
        order = [slice_axis + 1, 0, axes[0] + 1, axes[1] + 1]
        data = np.concatenate(
            [img.transpose(order) for img in data], axis=0)

    # Only keep slices with some signal in the selected channel.
    valid_indices = [
        idx for idx in range(len(data)) if data[idx, channel].max() > 0]
    if not valid_indices:
        logger.warning("No slice with positive values in channel %s: "
                       "nothing to plot.", channel)
        return
    nb_samples = min(nb_samples, len(valid_indices))
    if random:
        # Sample without replacement so that no slice is displayed twice.
        indices = np.random.choice(
            len(valid_indices), size=nb_samples, replace=False)
    else:
        indices = range(nb_samples)

    # Plot data on grid: one column per sample, one row per channel.
    fig = plt.figure(figsize=(15, 7), dpi=200)
    if title is not None:
        fig.suptitle(title)
    for cnt1, ind in enumerate(indices):
        ind = valid_indices[ind]
        for cnt2 in range(nb_channels):
            if rgb:
                # (channels, X, Y) -> (X, Y, channels) as imshow expects.
                im = data[ind].transpose(1, 2, 0)
                cmap = None
            else:
                im = data[ind, cnt2]
                cmap = "gray"
            plt.subplot(nb_channels, nb_samples,
                        nb_samples * cnt2 + cnt1 + 1)
            plt.axis("off")
            if cnt2 == 0:
                # Only the first row of each column carries the caption.
                if labels is None:
                    plt.title("Image " + str(ind))
                else:
                    plt.title(labels[ind])
            plt.imshow(im, cmap=cmap)
def plot_segmentation_data(data, mask, slice_axis=2, nb_samples=5):
    """ Display 'nb_samples' images and segmentation masks stored in data
    and mask.

    Currently support 2D or 3D dataset of the form (samples, channels, dim).
    The mask channel axis is collapsed with an argmax, and only the slices
    containing at least one pixel whose argmax is > 0 are eligible. The
    selected samples are drawn at random (without replacement). The top row
    shows the first image channel; the bottom row shows the label map with
    the image blended on top.

    Parameters
    ----------
    data: array (samples, channels, X, Y) or (samples, channels, X, Y, Z)
        the data to be displayed.
    mask: array (samples, channels, dim)
        the mask data to be overlayed, with the same layout as 'data' (it is
        reorganized exactly like 'data' for 3D inputs).
    slice_axis: int, default 2
        the slice axis (0, 1 or 2, relative to the spatial dims) for 3D data.
    nb_samples: int, default 5
        the number of samples to be displayed; capped by the number of
        eligible slices.

    Raises
    ------
    ValueError
        if the data dimension is not 4 or 5, or if 'slice_axis' is not
        0, 1 or 2 for 3D data.
    """
    # Check input parameters
    if data.ndim not in (4, 5):
        raise ValueError("Unsupported data dimension.")

    # Reorganize 3D data: transpose each volume so that 'slice_axis' comes
    # first, then stack all slices as samples (same ordering for the mask).
    if data.ndim == 5:
        axes = [0, 1, 2]
        if slice_axis not in axes:
            raise ValueError(
                "Unsupported slice axis: {0}.".format(slice_axis))
        axes.remove(slice_axis)
        order = [slice_axis + 1, 0, axes[0] + 1, axes[1] + 1]
        data = np.concatenate(
            [img.transpose(order) for img in data], axis=0)
        mask = np.concatenate(
            [img.transpose(order) for img in mask], axis=0)

    # Collapse the mask channel axis into per-pixel channel indices
    # (presumably one channel per class, channel 0 being background).
    mask = np.argmax(mask, axis=1)
    valid_indices = [idx for idx in range(len(mask)) if mask[idx].max() > 0]
    logger.debug("Mask shape: %s, number of non-empty slices: %d",
                 mask.shape, len(valid_indices))
    if not valid_indices:
        logger.warning("No slice with a label index > 0: nothing to plot.")
        return
    nb_samples = min(nb_samples, len(valid_indices))
    # Sample without replacement so that no slice is displayed twice.
    indices = np.random.choice(
        len(valid_indices), size=nb_samples, replace=False)

    # Plot data on grid
    plt.figure(figsize=(15, 7), dpi=200)
    for cnt, ind in enumerate(indices):
        ind = valid_indices[ind]
        im = data[ind, 0]
        # Top row: first image channel.
        plt.subplot(2, nb_samples, cnt + 1)
        plt.axis("off")
        plt.imshow(im, cmap="gray")
        # Bottom row: label map with the image blended on top.
        mask_im = mask[ind]
        plt.subplot(2, nb_samples, cnt + 1 + nb_samples)
        plt.axis("off")
        plt.imshow(mask_im, cmap="jet")
        plt.imshow(im, cmap="gray", alpha=0.4)
def rescale_intensity(arr, in_range, out_range):
    """ Return arr after stretching or shrinking its intensity levels.

    Values are first clipped to 'in_range', then linearly mapped so that
    in_range[0] -> out_range[0] and in_range[1] -> out_range[1].

    Parameters
    ----------
    arr: array
        input array.
    in_range, out_range: 2-tuple
        min and max intensity values of input and output arr.

    Returns
    -------
    out: array
        array after rescaling its intensity (floating point). If 'in_range'
        has zero width (min == max), every value is mapped to out_range[0]
        instead of producing NaN/inf through a division by zero.
    """
    imin, imax = in_range
    omin, omax = out_range
    out = np.clip(arr, imin, imax)
    span = float(imax - imin)
    if span == 0:
        # Degenerate input range (e.g. a constant/blank slice rescaled with
        # its own min and max): all clipped values equal imin, so send them
        # to the lower output bound.
        return np.full_like(out, omin, dtype=float)
    out = (out - imin) / span
    return out * (omax - omin) + omin

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.