Menu

Helper Module for Deep Learning.

Source code for pynet.plotting.network

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Module that provides tools to display a graph.
"""

# System import
import sys
import os
import time
from pprint import pprint
import tempfile
import weakref
import operator
import tempfile

# Third party import
import torch
import hiddenlayer as hl
from PySide2 import QtCore, QtGui, QtWidgets
from torchviz import make_dot

# Module import
from .graph import Graph, GraphNode
from pynet.plotting.colors import *


def plot_net_rescue(model, shape, outfileroot=None):
    """ Render the network graph with torchviz and save it as a PNG file.

    A random input batch of the requested shape is pushed through the model
    and the resulting output is traced with 'make_dot'. The rendered
    picture is also opened in the default viewer.

    Parameters
    ----------
    model: Net
        the network model.
    shape: list of int
        the shape of a classical input batch dataset.
    outfileroot: str, default None
        the file path without extension. If None, the picture is written in
        a newly created temporary directory.

    Returns
    -------
    outfile: str
        the path to the generated PNG.
    """
    # Trace the model on a random batch
    dummy_batch = torch.randn(shape)
    dot = make_dot(model(dummy_batch), params=dict(model.named_parameters()))
    dot.format = "png"

    # Select the destination folder and file name
    if outfileroot is not None:
        dirpath, basename = os.path.split(outfileroot)
    else:
        dirpath, basename = tempfile.mkdtemp(), "pynet_graph"

    # Render and display the picture
    dot.render(directory=dirpath, filename=basename, view=True)
    return os.path.join(dirpath, basename + ".png")
def plot_net(model, shape, static=True, outfileroot=None):
    """ Display the network graph representation and optionally save it in
    a PDF file.

    Sometimes the 'get_trace_graph' pytorch function fails: use the
    'plot_net_rescue' function instead.

    Parameters
    ----------
    model: Net
        the network model.
    shape: list of int
        the shape of a classical input batch dataset.
    static: bool, default True
        create a static view (a picture of the graph in a scroll area) or
        a dynamic view (an interactive 'GraphView').
    outfileroot: str, default None
        the file path without extension to generate PDF. If None, no PDF
        is saved.

    Returns
    -------
    outfile: str
        the path to the generated PDF, or None if 'outfileroot' is None.

    Raises
    ------
    ValueError
        if 'outfileroot' is given but the expected PDF file is not found
        after the save.
    """
    # Create application
    app = QtWidgets.QApplication.instance()
    if app is None:
        app = QtWidgets.QApplication(sys.argv)

    # Build the graph description
    hl_graph = hl.build_graph(model, torch.zeros(shape))
    hl_graph.theme = hl.graph.THEMES["blue"].copy()

    # Save the graph if requested
    outfile = None
    if outfileroot is not None:
        hl_graph.save(outfileroot)
        outfile = outfileroot + ".pdf"
        if not os.path.isfile(outfile):
            raise ValueError("'{0}' has not been generated.".format(outfile))

    # Create view
    if static:
        # The picture is loaded by the PDFView constructor, so the temporary
        # folder can be removed as soon as the widget is created
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfileroot = os.path.join(tmpdir, "graph")
            hl_graph.save(tmpfileroot, format="png")
            widget = PDFView(tmpfileroot + ".png")
        view = QtWidgets.QScrollArea()
        view.setWidgetResizable(True)
        view.setWidget(widget)
    else:
        graph = Graph()
        # Prefix each label with a counter to get unique node names
        nodes_map = {}
        for cnt, (key, node) in enumerate(hl_graph.nodes.items(), start=1):
            label = node.title
            if node.caption:
                label += node.caption
            if node.repeat:
                label += str(node.repeat)
            nodes_map[key] = "{0}-{1}".format(cnt, label)
            graph.add_node(GraphNode(nodes_map[key], node))
        # The edge labels are not displayed in the view
        for key1, key2, _ in hl_graph.edges:
            graph.add_link(nodes_map[key1], nodes_map[key2])
        view = GraphView(graph)

    # Display
    view.show()
    app.exec_()

    return outfile
class PDFView(QtWidgets.QWidget):
    """ A widget to visualize a picture of a graph.
    """
    def __init__(self, path):
        """ Initialize the PDFView class.

        Parameters
        ----------
        path: str
            the path to the picture to display: it is loaded in a QPixmap
            when the widget is created.
        """
        super(PDFView, self).__init__()
        self.path = path
        layout = QtWidgets.QVBoxLayout(self)
        self.label = QtWidgets.QLabel()
        layout.addWidget(self.label)
        self.pixmap = QtGui.QPixmap(self.path)
        self.label.setPixmap(self.pixmap)


class Control(QtWidgets.QGraphicsPolygonItem):
    """ Create a glyph for each control connection.
    """
    def __init__(self, name, height, width, optional, parent=None):
        """ Initialize the Control class.

        Parameters
        ----------
        name: str
            the control name.
        height, width: int
            the control size.
        optional: bool
            option to color the glyph.
        parent: QGraphicsItem, default None
            the parent item.
        """
        # Inheritance
        super(Control, self).__init__(parent)

        # Class parameters
        self.name = name
        self.optional = optional
        color = self._color(optional)
        self.brush = QtGui.QBrush(QtCore.Qt.SolidPattern)
        self.brush.setColor(color)

        # Set graphic item properties
        self.setAcceptedMouseButtons(QtCore.Qt.LeftButton)

        # Define the widget: a triangle pointing to the right
        polygon = QtGui.QPolygonF([
            QtCore.QPointF(0, 0),
            QtCore.QPointF(width, (height - 5) / 2.0),
            QtCore.QPointF(0, height - 5)])
        self.setPen(QtGui.QPen(QtCore.Qt.NoPen))
        self.setPolygon(polygon)
        self.setBrush(self.brush)
        self.setZValue(3)

    def _color(self, optional):
        """ Define the color of a control glyph depending on its status.

        Parameters
        ----------
        optional: bool (mandatory)
            option to color the glyph.

        Returns
        -------
        color: QColor
            the glyph color.
        """
        if optional:
            color = QtCore.Qt.darkGreen
        else:
            color = QtCore.Qt.black
        return color

    def get_control_point(self):
        """ Give the relative location of the control glyph in the parent
        widget.

        Returns
        -------
        position: QPointF
            the control glyph position.
        """
        size = self.boundingRect().size()
        point = QtCore.QPointF(size.width() / 2.0, size.height() / 2.0)
        return self.mapToParent(point)


class Node(QtWidgets.QGraphicsItem):
    """ A box node.
    """
    _colors = {
        "default": (RED_1, RED_2, LIGHT_RED_1, LIGHT_RED_2),
        "choice1": (SAND_1, SAND_2, LIGHT_SAND_1, LIGHT_SAND_2),
        "choice2": (DEEP_PURPLE_1, DEEP_PURPLE_2, PURPLE_1, PURPLE_2),
        "choice3": (BLUE_1, BLUE_2, LIGHT_BLUE_1, LIGHT_BLUE_2)
    }

    def __init__(self, name, inputs, outputs, active=True, style=None,
                 graph=None, parent=None):
        """ Initialize the Node class.

        Parameters
        ----------
        name: string
            a name for the box node.
        inputs: list of str
            the box input controls. If None no input will be created.
        outputs: list of str
            the box output controls. If None no output will be created.
        active: bool, default True
            a special color will be applied on the node rendering depending
            of this parameter.
        style: string, default None
            the style that will be applied to tune the box rendering: one of
            the '_colors' keys, 'default' is used if None.
        graph: Graph, default None
            a sub-graph item.
        parent: QGraphicsItem, default None
            the parent item.
        """
        # Inheritance
        super(Node, self).__init__(parent)

        # Class parameters
        self.style = style or "default"
        self.name = name
        self.graph = graph
        self.inputs = inputs or []
        self.outputs = outputs or []
        self.active = active
        self.input_controls = {}
        self.output_controls = {}
        self.embedded_box = None

        # Set graphic item properties: with PySide2 the graphics items are
        # defined in QtWidgets and not in QtGui
        self.setFlag(QtWidgets.QGraphicsItem.ItemIsMovable)
        self.setAcceptedMouseButtons(
            QtCore.Qt.LeftButton | QtCore.Qt.RightButton |
            QtCore.Qt.MiddleButton)

        # Define rendering colors: the first two colors of the style are
        # used for an active node, the last two otherwise
        bgd_color_indices = [2, 3]
        if self.active:
            bgd_color_indices = [0, 1]
        self.background_brush = self._get_brush(
            *operator.itemgetter(*bgd_color_indices)(
                self._colors[self.style]))
        self.title_brush = self._get_brush(
            *operator.itemgetter(2, 3)(self._colors[self.style]))

        # Construct the node
        self._build()

    def get_title(self):
        """ Create a title for the node.

        Returns
        -------
        title: str
            the node name.
        """
        return self.name

    def _build(self, margin=5):
        """ Create a node representing a box.

        Parameters
        ----------
        margin: int (optional, default 5)
            the default margin.
        """
        # Create a title for the node
        self.title = QtWidgets.QGraphicsTextItem(self.get_title(), self)
        font = self.title.font()
        font.setWeight(QtGui.QFont.Bold)
        self.title.setFont(font)
        self.title.setPos(margin, margin)
        self.title.setZValue(2)
        self.title.setParentItem(self)

        # Define the default control position
        control_position = (
            margin + margin + self.title.boundingRect().size().height())

        # Create the input controls
        for input_name in self.inputs:

            # Create the control representation
            control_glyph, control_text = self._create_control(
                input_name, control_position, is_output=False, margin=margin)

            # Update the class parameters
            self.input_controls[input_name] = (control_glyph, control_text)

            # Update the next control position
            control_position += control_text.boundingRect().size().height()

        # Create the output controls
        for output_name in self.outputs:

            # Create the control representation
            control_glyph, control_text = self._create_control(
                output_name, control_position, is_output=True, margin=margin)

            # Update the class parameters
            self.output_controls[output_name] = (control_glyph, control_text)

            # Update the next control position
            control_position += control_text.boundingRect().size().height()

        # Define the box node
        self.box = QtWidgets.QGraphicsRectItem(self)
        self.box.setBrush(self.background_brush)
        self.box.setPen(QtGui.QPen(QtCore.Qt.NoPen))
        self.box.setZValue(-1)
        self.box.setParentItem(self)
        self.box.setRect(self.contentsRect())

        # Define the title background with the same width as the box
        self.box_title = QtWidgets.QGraphicsRectItem(self)
        rect = self.title.mapRectToParent(self.title.boundingRect())
        brect = self.contentsRect()
        brect.setWidth(brect.right() - margin)
        rect.setWidth(brect.width())
        self.box_title.setRect(rect)
        self.box_title.setBrush(self.title_brush)
        self.box_title.setPen(QtGui.QPen(QtCore.Qt.NoPen))
        self.box_title.setParentItem(self)

    def _create_control(self, control_name, control_position, is_output=False,
                        control_width=12, margin=5):
        """ Create a control representation: small glyph and control name.

        Parameters
        ----------
        control_name: str (mandatory)
            the name of the control to render.
        control_position: int (mandatory)
            the position (height) of the control to render.
        is_output: bool (optional, default False)
            an input control glyph is displayed on the left while an output
            control glyph is displayed on the right.
        control_width: int (optional, default 12)
            the default size of the control glyph.
        margin: int (optional, default 5)
            the default margin.

        Returns
        -------
        control_glyph: Control
            the control glyph item.
        control_text: QGraphicsTextItem
            the associated control text item.
        """
        # No control is flagged as optional here
        is_optional = False

        # Create the control representation
        control_text = QtWidgets.QGraphicsTextItem(self)
        control_text.setHtml(control_name)
        control_name = "{0}:{1}".format(self.name, control_name)
        control_glyph = Control(
            control_name, control_text.boundingRect().size().height(),
            control_width, optional=is_optional, parent=self)
        control_text.setZValue(2)
        control_glyph_width = control_glyph.boundingRect().size().width()
        control_title_width = self.title.boundingRect().size().width()
        control_text.setPos(control_glyph_width + margin, control_position)
        if is_output:
            control_glyph.setPos(
                control_title_width - control_glyph_width,
                control_position)
        else:
            control_glyph.setPos(margin, control_position)
        control_text.setParentItem(self)
        control_glyph.setParentItem(self)

        return control_glyph, control_text

    def _get_brush(self, color1, color2):
        """ Create a brush that has a style, a color, a gradient and a
        texture.

        Parameters
        ----------
        color1, color2: QtGui.QColor (mandatory)
            edge box colors used to define the gradient.

        Returns
        -------
        brush: QBrush
            a brush filled with a vertical linear gradient.
        """
        gradient = QtGui.QLinearGradient(0, 0, 0, 50)
        gradient.setColorAt(0, color1)
        gradient.setColorAt(1, color2)
        return QtGui.QBrush(gradient)

    def contentsRect(self):
        """ Returns the area inside the widget's margins.

        The background items ('box' and 'box_title') and the hidden children
        are not considered.

        Returns
        -------
        brect: QRectF
            the bounding rectangle (left, top, right, bottom): a null
            rectangle is returned if no visible child is available.
        """
        first = True
        # Default value returned when there is no visible child
        brect = QtCore.QRectF()
        excluded = []
        for name in ("box", "box_title"):
            if hasattr(self, name):
                excluded.append(getattr(self, name))
        for child in self.childItems():
            if not child.isVisible() or child in excluded:
                continue
            item_rect = self.mapRectFromItem(child, child.boundingRect())
            if first:
                first = False
                brect = item_rect
            else:
                if item_rect.left() < brect.left():
                    brect.setLeft(item_rect.left())
                if item_rect.top() < brect.top():
                    brect.setTop(item_rect.top())
                if item_rect.right() > brect.right():
                    brect.setRight(item_rect.right())
                if item_rect.bottom() > brect.bottom():
                    brect.setBottom(item_rect.bottom())
        return brect

    def boundingRect(self):
        """ Returns the bounding rectangle of the node: the area covered by
        its visible children.

        Returns
        -------
        brect: QRectF
            the bounding rectangle (x, y, w, h).
        """
        return self.contentsRect()

    def paint(self, painter, option, widget=None):
        """ Nothing to paint: the rendering is delegated to the child items.
        """
        pass

    def mouseDoubleClickEvent(self, event):
        """ If a sub-graph is available emit a 'subgraph_clicked' signal.
        """
        if self.graph is not None:
            self.scene().subgraph_clicked.emit(self.name, self.graph,
                                               event.modifiers())
            event.accept()
        else:
            event.ignore()

    def add_subgraph_view(self, graph, margin=5):
        """ Display a sub-graph box in a node.

        The first call creates the embedded view, the next calls toggle its
        visibility.

        Parameters
        ----------
        graph: Graph
            the sub-graph box to display.
        margin: int (optional, default 5)
            the default margin.
        """
        # Create a embedded proxy view
        if self.embedded_box is None:
            view = GraphView(graph)
            proxy_view = EmbeddedSubGraphItem(view)
            view._graphics_item = weakref.proxy(proxy_view)
            proxy_view.setParentItem(self)
            posx = margin + self.box.boundingRect().width()
            proxy_view.setPos(posx, margin)
            self.embedded_box = proxy_view

        # Change visibility property of the embedded proxy view
        else:
            if self.embedded_box.isVisible():
                self.embedded_box.hide()
            else:
                self.embedded_box.show()


class EmbeddedSubGraphItem(QtWidgets.QGraphicsProxyWidget):
    """ QGraphicsItem containing a sub-graph box view.
    """
    def __init__(self, sub_graph_view):
        """ Initialize the EmbeddedSubGraphItem.

        Parameters
        ----------
        sub_graph_view: GraphView
            the sub-graph view.
        """
        # Inheritance
        super(EmbeddedSubGraphItem, self).__init__()

        # Define rendering options
        sub_graph_view.setHorizontalScrollBarPolicy(
            QtCore.Qt.ScrollBarAlwaysOn)
        sub_graph_view.setVerticalScrollBarPolicy(
            QtCore.Qt.ScrollBarAlwaysOn)

        # Add the sub-graph widget
        self.setWidget(sub_graph_view)


class Link(QtWidgets.QGraphicsPathItem):
    """ A link between boxes.
    """
    def __init__(self, src_position, dest_position, parent=None):
        """ Initialize the Link class.

        Parameters
        ----------
        src_position: QPointF (mandatory)
            the source control glyph position.
        dest_position: QPointF (mandatory)
            the destination control glyph position.
        parent: QGraphicsItem, default None
            the parent item.
        """
        # Inheritance
        super(Link, self).__init__(parent)

        # Define the color rendering
        pen = QtGui.QPen()
        pen.setWidth(2)
        pen.setBrush(RED_2)
        pen.setCapStyle(QtCore.Qt.RoundCap)
        pen.setJoinStyle(QtCore.Qt.RoundJoin)
        self.setPen(pen)

        # Draw the link
        self.update(src_position, dest_position)
        self.setZValue(0.5)

    def update(self, src_position, dest_position):
        """ Update the link extreme positions.

        The link is a cubic curve that leaves the source and reaches the
        destination horizontally.

        Parameters
        ----------
        src_position: QPointF (mandatory)
            the source control glyph position.
        dest_position: QPointF (mandatory)
            the destination control glyph position.
        """
        path = QtGui.QPainterPath()
        path.moveTo(src_position.x(), src_position.y())
        path.cubicTo(src_position.x() + 100, src_position.y(),
                     dest_position.x() - 100, dest_position.y(),
                     dest_position.x(), dest_position.y())
        self.setPath(path)


class GraphScene(QtWidgets.QGraphicsScene):
    """ Define a scene representing a graph.
    """
    # Signal emitted when a sub graph has to be open
    subgraph_clicked = QtCore.Signal(str, Graph, QtCore.Qt.KeyboardModifiers)

    def __init__(self, graph, parent=None):
        """ Initialize the GraphScene class.

        Parameters
        ----------
        graph: Graph
            graph to be displayed.
        parent: QWidget, default None
            parent widget.
        """
        # Inheritance
        super(GraphScene, self).__init__(parent)

        # Class parameters
        self.graph = graph
        self.gnodes = {}
        self.glinks = {}
        self.gpositions = {}

        # Add event to update links
        self.changed.connect(self.update_links)

    def _link_positions(self, linkdesc):
        """ Compute the scene positions of the two extremities of a link.

        Parameters
        ----------
        linkdesc: string (mandatory)
            the link description, see 'parse_link_description'.

        Returns
        -------
        src_position, dest_position: QPointF
            the source and destination control glyph positions in the scene.
        """
        # Parse the link description
        src_control, dest_control = self.parse_link_description(linkdesc)

        # Get the source and destination nodes/controls
        src_gnode = self.gnodes[src_control[0]]
        dest_gnode = self.gnodes[dest_control[0]]
        src_control_glyph = src_gnode.output_controls[src_control[1]][0]
        dest_control_glyph = dest_gnode.input_controls[dest_control[1]][0]

        return (
            src_gnode.mapToScene(src_control_glyph.get_control_point()),
            dest_gnode.mapToScene(dest_control_glyph.get_control_point()))

    def update_links(self):
        """ Update the node positions and associated links.
        """
        for node in self.items():
            if isinstance(node, Node):
                self.gpositions[node.name] = node.pos()

        for linkdesc, link in self.glinks.items():
            link.update(*self._link_positions(linkdesc))

    def draw(self):
        """ Draw the scene representing the graph.
        """
        # Add the graph boxes
        for box_name, box in self.graph._nodes.items():

            # Define the box type and check if we are dealing with a graph
            # box
            if isinstance(box.meta, Graph):
                style = "choice1"
            else:
                style = "choice3"

            # Add the box: a single unnamed input/output control is created
            # when the corresponding degree is not null
            self.add_box(
                box_name,
                inputs=[""] * (0 if box.links_from_degree == 0 else 1),
                outputs=[""] * (0 if box.links_to_degree == 0 else 1),
                active=True,
                style=style,
                graph=box.meta)

        # If no node position is defined used an automatic setup
        # based on a graph representation
        if not self.gpositions:
            scale = 0.0
            for node in self.gnodes.values():
                scale = max(node.box.boundingRect().width(), scale)
                scale = max(node.box.boundingRect().height(), scale)
            scale *= 4
            box_positions = self.graph.layout(scale=scale)
            for node_name, node_pos in box_positions.items():
                self.gnodes[node_name].setPos(QtCore.QPointF(*node_pos))

        # Create the links between the boxes
        for from_box_name, to_box_name in self.graph._links:
            self.add_link("{0}.->{1}.".format(from_box_name, to_box_name))

    def parse_link_description(self, linkdesc):
        """ Parse a link description.

        Parameters
        ----------
        linkdesc: string (mandatory)
            link representation with the source and destination separated
            by '->' and control descriptions of the form
            '<box_name>.<control_name>' or '<control_name>' for graph
            input or output controls.

        Returns
        -------
        src_control: 2-uplet
            the source control representation (box_name, control_name).
        dest_control: 2-uplet
            the destination control representation (box_name, control_name).
        """
        # Parse description: only the last dot separates the box name from
        # the control name since a box name may itself contain dots
        srcdesc, destdesc = linkdesc.split("->")
        src_control = srcdesc.rsplit(".", 1)
        dest_control = destdesc.rsplit(".", 1)

        # Deal with graph input and output controls
        if len(src_control) == 1:
            src_control.insert(0, "inputs")
        if len(dest_control) == 1:
            dest_control.insert(0, "outputs")

        return tuple(src_control), tuple(dest_control)

    def add_box(self, name, inputs, outputs, active=True, style=None,
                graph=None):
        """ Add a box in the graph representation.

        Parameters
        ----------
        name: string
            a name for the box.
        inputs: list of str
            the box input controls.
        outputs: list of str
            the box output controls.
        active: bool, default True
            a special color will be applied on the box rendering depending
            of this parameter.
        style: string, default None
            the style that will be applied to tune the box rendering.
        graph: Graph, default None
            the sub-graph item.
        """
        # Create the node widget that represents the box
        box_node = Node(name, inputs, outputs, active=active, style=style,
                        graph=graph)

        # Update the scene
        self.addItem(box_node)
        node_position = self.gpositions.get(name)
        if node_position is not None:
            box_node.setPos(node_position)
        self.gnodes[name] = box_node

    def add_link(self, linkdesc):
        """ Define a link between two nodes in the graph.

        Parameters
        ----------
        linkdesc: string (mandatory)
            link representation with the source and destination separated
            by '->' and control descriptions of the form
            '<box_name>.<control_name>' or '<control_name>' for graph
            input or output controls.
        """
        # Create the link
        glink = Link(*self._link_positions(linkdesc))

        # Update the scene
        self.addItem(glink)
        self.glinks[linkdesc] = glink

    def keyPressEvent(self, event):
        """ Display the graph box positions when the 'p' key is pressed.
        """
        super(GraphScene, self).keyPressEvent(event)
        if not event.isAccepted() and event.key() == QtCore.Qt.Key_P:
            event.accept()
            posdict = {key: (value.x(), value.y())
                       for key, value in self.gpositions.items()}
            pprint(posdict)

    def helpEvent(self, event):
        """ Display tooltips on controls and links.
        """
        # With Qt5 the device transform is a mandatory argument
        item = self.itemAt(event.scenePos(), QtGui.QTransform())
        if isinstance(item, Control):
            # A glyph has no 'control' attribute by default: fall back on
            # the glyph itself
            control = getattr(item, "control", item)
            item.setToolTip("type: {0} - optional: {1}".format(
                control.__class__.__name__, item.optional))
        super(GraphScene, self).helpEvent(event)


class GraphView(QtWidgets.QGraphicsView):
    """ Graph representation (using boxes and arrows).

    Based on Qt QGraphicsView, this can be used as a Qt QWidget.

    Qt signals are emitted:

    * on a double click on a sub-graph box to display the sub-graph. If
      'ctrl' is pressed a new window is created otherwise the view is
      embedded.
    * on the wheel to zoom in or zoom out.
    * on the keyboard 'p' key to display the box node positions.

    Attributes
    ----------
    scene: GraphScene
        the main scene.
    """
    # Signal emitted when a sub graph has to be open
    subgraph_clicked = QtCore.Signal(str, Graph, QtCore.Qt.KeyboardModifiers)

    def __init__(self, graph, parent=None):
        """ Initialize the GraphView class.

        Parameters
        ----------
        graph: Graph
            graph to be displayed.
        parent: QWidget, default None
            parent widget.

        Raises
        ------
        Exception
            if 'graph' is not a Graph instance.
        """
        # Inheritance
        super(GraphView, self).__init__(parent)

        # Class parameters
        self.scene = None

        # Check that we have a graph
        if not isinstance(graph, Graph):
            raise Exception("'{0}' is not a valid graph.".format(graph))

        # Create the graph representation
        self.set_graph(graph)

    def set_graph(self, graph):
        """ Assigns a new graph to the view.

        Parameters
        ----------
        graph: Graph
            graph to be displayed.
        """
        # Define the graph box positions
        if hasattr(graph, "_box_positions"):
            box_positions = dict(
                (box_name, QtCore.QPointF(*box_position))
                for box_name, box_position in graph._box_positions.items())
        else:
            box_positions = {}

        # Create the scene
        self.scene = GraphScene(graph, self)
        self.scene.gpositions = box_positions
        self.scene.draw()

        # Update the current view
        self.setWindowTitle("Graph representation")
        self.setScene(self.scene)

        # Try to initialize the current view scale factor: read the
        # attribute that has been checked
        if hasattr(graph, "_scale"):
            self.scale(graph._scale, graph._scale)

        # Define signals
        self.scene.subgraph_clicked.connect(self.subgraph_clicked)
        self.scene.subgraph_clicked.connect(self.display_subgraph)

    def zoom_in(self):
        """ Zoom the view in by applying a 1.2 zoom factor.
        """
        self.scale(1.2, 1.2)

    def zoom_out(self):
        """ Zoom the view out by applying a 1 / 1.2 zoom factor.
        """
        self.scale(1.0 / 1.2, 1.0 / 1.2)

    def display_subgraph(self, node_name, graph, modifiers):
        """ Event to display the selected sub-graph.

        If 'ctrl' is pressed a new window is created, otherwise the new
        view will be embedded in its parent node box.

        Parameters
        ----------
        node_name: str
            the node name.
        graph: Graph
            the sub-graph box to display.
        modifiers: QtCore.Qt.KeyboardModifiers
            the keyboard modifiers pressed during the click.
        """
        # Open a new window
        if modifiers & QtCore.Qt.ControlModifier:
            view = GraphView(graph)
            QtCore.QObject.setParent(view, self.window())
            view.setAttribute(QtCore.Qt.WA_DeleteOnClose)
            view.setWindowTitle(node_name)
            view.show()

        # Embedded sub-graph inside its parent node
        else:
            node = self.scene.gnodes.get(node_name)
            # Nothing to do if the node is not part of this scene
            if node is not None:
                node.add_subgraph_view(graph)

    def wheelEvent(self, event):
        """ Change the scene zoom factor.

        The event is forwarded unchanged when the mouse is over an embedded
        sub-graph view.
        """
        item = self.itemAt(event.pos())
        if not isinstance(item, QtWidgets.QGraphicsProxyWidget):
            if event.delta() < 0:
                self.zoom_out()
            else:
                self.zoom_in()
            event.accept()
        else:
            super(GraphView, self).wheelEvent(event)

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.