Menu

Helper Module for Deep Learning.

Source code for pynet.plotting.board

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Common functions to display a dynamic board.
"""

# Import
import sys
import json
import logging
import numpy as np
import visdom
import torch
from subprocess import Popen, PIPE


# Global parameters
logger = logging.getLogger("pynet")


[docs]class Board(object): """ Define a dynamic board. It can be used to gather interesting plottings during a training. """
[docs] def __init__(self, port=8097, host="http://localhost", env="main", display_pred=False, prepare_pred=None): """ Initilaize the class. Parameters ---------- port: int, default 8097 the port on which the visdom server is launched. host: str, default 'http://localhost' the host on which visdom is launched. env: str, default 'main' the environment to be used. display_pred: bool, default False if set render the predicted images. prepare_pred: callable, defaultt None a function that transforms the predictions into a Nx1xXxY or Nx3xXxY array, with N the number of images. """ self.port = port self.host = host self.env = env self.display_pred = display_pred self.prepare_pred = prepare_pred self.plots = {} logger.debug("Create viewer on host {0} port {1}.".format(host, port)) self.viewer = visdom.Visdom( port=self.port, server=self.host, env=self.env) while len(logging.root.handlers) > 0: logging.root.removeHandler(logging.root.handlers[-1]) self.server = None if not self.viewer.check_connection(): self._create_visdom_server() current_data = json.loads(self.viewer.get_window_data()) for key in current_data: logger.debug("Closing plot {0}.".format(key)) self.viewer.close(win=key)
def __del__(self): """ Class destructor. """ if self.server is not None: logger.debug("Stoping visdom server.") self.server.kill() self.server.wait() def _create_visdom_server(self): """ It starts a new visdom server. """ current_python = sys.executable cmd = "{0} -m visdom.server -p {1}".format(current_python, self.port) logger.debug("Starting visdom server:\n{0}".format(cmd)) self.server = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
[docs] def update_plots(self, data, epoch): """ Update/create plots from the input data. Parameters ---------- data: dict the name and new value of the plot to be updated. epoch: int the current epoch. """ logger.debug("Board data update:\n{0}".format(data)) # current_data = json.loads(self.viewer.get_window_data()) for key, val in data.items(): if key == "val_pred": if not self.display_pred: continue images = np.asarray(val) if self.prepare_pred is not None: images = self.prepare_pred(val) if images.ndim != 4: raise ValueError( "You must define a function that transforms the " "predictions into a Nx1xXxY or Nx3xXxY array, with N " "the number of images.") self.viewer.images( images, opts={ "title": key, "caption": "y_pred"}, win=key) else: self.viewer.line( X=np.asarray([epoch]), Y=np.asarray([val]), opts={ "title": key, "xlabel": "iterations", "ylabel": key}, update="append", win=key)
[docs]def update_board(signal): """ Callback to update visdom board visualizer. Parameters ---------- signal: SignalObject an object with the trained model 'object', the emitted signal 'signal', the epoch number 'epoch' and the fold index 'fold'. """ net = signal.object.model emitted_signal = signal.signal epoch = signal.epoch fold = signal.fold board = signal.object.board data = {} for key in signal.keys: if key in ("epoch", "fold"): continue value = getattr(signal, key) if key == "scheduler" and value is None: continue if key == "scheduler" and value is not None: if hasattr(value, "get_last_lr"): value = value.get_last_lr()[0] else: value = value._last_lr[0] key += "_lr" if isinstance(value, torch.Tensor): value = value.cpu().detach().numpy().tolist() data[key] = value board.update_plots(data, epoch)

Follow us

© 2019, pynet developers .
Inspired by AZMIND template.