diff --git a/conda/dev-environment-unix.yml b/conda/dev-environment-unix.yml index fa3740a5b..f503ba5ef 100644 --- a/conda/dev-environment-unix.yml +++ b/conda/dev-environment-unix.yml @@ -18,6 +18,7 @@ dependencies: - python-graphviz - gtest - httpx>=0.20,<1 + - ipydagred3 - isort>=5,<6 - libarrow=16 - libboost>=1.80.0 diff --git a/csp/__init__.py b/csp/__init__.py index a223dc7aa..50f342017 100644 --- a/csp/__init__.py +++ b/csp/__init__.py @@ -27,7 +27,7 @@ ) from csp.impl.wiring.context import clear_global_context, new_global_context from csp.math import * -from csp.showgraph import show_graph +from csp.showgraph import * from . import stats diff --git a/csp/dataframe.py b/csp/dataframe.py index 8aba3814c..8bc7820e2 100644 --- a/csp/dataframe.py +++ b/csp/dataframe.py @@ -3,6 +3,7 @@ import csp.baselib from csp.impl.wiring.edge import Edge +from csp.showgraph import show_graph # Lazy declaration below to avoid perspective import RealtimePerspectiveWidget = None @@ -143,12 +144,7 @@ def _eval(self, starttime: datetime, endtime: datetime = None, realtime: bool = return csp.run(self._eval_graph, starttime=starttime, endtime=endtime, realtime=realtime) def show_graph(self): - from PIL import Image - - import csp.showgraph - - buffer = csp.showgraph.generate_graph(self._eval_graph) - return Image.open(buffer) + show_graph(self._eval_graph, graph_filename=None) def to_pandas(self, starttime: datetime, endtime: datetime): import pandas @@ -222,7 +218,9 @@ def join(self): self._runner.join() except ImportError: - raise ImportError("eval_perspective requires perspective-python installed") + raise ModuleNotFoundError( + "eval_perspective requires perspective-python installed. See https://perspective.finos.org for installation instructions." + ) if not realtime: df = self.to_pandas(starttime, endtime) diff --git a/csp/impl/pandas_accessor.py b/csp/impl/pandas_accessor.py index 3d7a0eb1b..8c3c3cc17 100644 --- a/csp/impl/pandas_accessor.py +++ b/csp/impl/pandas_accessor.py @@ -10,6 +10,7 @@ from csp.impl.pandas_ext_type import TsDtype, is_csp_type from csp.impl.struct import define_nested_struct from csp.impl.wiring.edge import Edge +from csp.showgraph import show_graph T = TypeVar("T") @@ -375,12 +376,7 @@ def show_graph(self): """Show the graph corresponding to the evaluation of all the edges. For large series, this may be very large, so it may be helpful to call .head() first. """ - from PIL import Image - - import csp.showgraph - - buffer = csp.showgraph.generate_graph(self._eval_graph, "png") - return Image.open(buffer) + return show_graph(self._eval_graph, graph_filename=None) @register_series_accessor("to_csp") @@ -626,12 +622,7 @@ def show_graph(self): """Show the graph corresponding to the evaluation of all the edges. For large series, this may be very large, so it may be helpful to call .head() first. """ - from PIL import Image - - import csp.showgraph - - buffer = csp.showgraph.generate_graph(self._eval_graph, "png") - return Image.open(buffer) + show_graph(self._eval_graph, graph_filename=None) @register_dataframe_accessor("to_csp") diff --git a/csp/profiler.py b/csp/profiler.py index 56f74825d..965207b10 100644 --- a/csp/profiler.py +++ b/csp/profiler.py @@ -382,7 +382,7 @@ def initialize(self, adapter: GenericPushAdapter, display_graphs: bool): try: import matplotlib # noqa: F401 except ImportError: - raise Exception("You must have matplotlib installed to display profiling data graphs.") + raise ModuleNotFoundError("You must have matplotlib installed to display profiling data graphs.") def get(self): try: @@ -478,7 +478,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def init_profiler(self): if self.http_port is not None: if not HAS_TORNADO: - raise Exception("You must have tornado installed to use the HTTP profiling extension.") + raise ModuleNotFoundError("You must have tornado installed to use the HTTP profiling extension.") adapter = GenericPushAdapter(Future) application = tornado.web.Application( diff --git a/csp/showgraph.py b/csp/showgraph.py index 74440c447..34fb1cc53 100644 --- a/csp/showgraph.py +++ b/csp/showgraph.py @@ -1,18 +1,54 @@ from collections import deque, namedtuple from io import BytesIO +from typing import Dict, Literal from csp.impl.wiring.runtime import build_graph -NODE = namedtuple("NODE", ["name", "label", "color", "shape"]) -EDGE = namedtuple("EDGE", ["start", "end"]) +_KIND = Literal["output", "input", ""] +_NODE = namedtuple("NODE", ["name", "label", "kind"]) +_EDGE = namedtuple("EDGE", ["start", "end"]) +_GRAPHVIZ_COLORMAP: Dict[_KIND, str] = {"output": "red", "input": "cadetblue1", "": "white"} -def _build_graphviz_graph(graph_func, *args, **kwargs): - from graphviz import Digraph +_GRAPHVIZ_SHAPEMAP: Dict[_KIND, str] = {"output": "rarrow", "input": "rarrow", "": "box"} + +_DAGRED3_COLORMAP: Dict[_KIND, str] = { + "output": "red", + "input": "#98f5ff", + "": "lightgrey", +} +_DAGRED3_SHAPEMAP: Dict[_KIND, str] = {"output": "diamond", "input": "diamond", "": "rect"} + +_NOTEBOOK_KIND = Literal["", "terminal", "notebook"] + +__all__ = ( + "generate_graph", + "show_graph_pil", + "show_graph_graphviz", + "show_graph_widget", + "show_graph", +) + + +def _notebook_kind() -> _NOTEBOOK_KIND: + try: + from IPython import get_ipython + shell = get_ipython().__class__.__name__ + if shell == "ZMQInteractiveShell": + return "notebook" + elif shell == "TerminalInteractiveShell": + return "terminal" + else: + return "" + except ImportError: + return "" + except NameError: + return "" + + +def _build_graph_for_viz(graph_func, *args, **kwargs): graph = build_graph(graph_func, *args, **kwargs) - digraph = Digraph(strict=True) - digraph.attr(rankdir="LR", size="150,150") rootnames = set() q = deque() @@ -29,30 +65,37 @@ def _build_graphviz_graph(graph_func, *args, **kwargs): name = str(id(nodedef)) visited.add(nodedef) if name in rootnames: # output node - color = "red" - shape = "rarrow" + kind = "output" elif not sum(1 for _ in nodedef.ts_inputs()): # input node - color = "cadetblue1" - shape = "rarrow" + kind = "input" else: - color = "white" - shape = "box" + kind = "" label = nodedef.__name__ if hasattr(nodedef, "__name__") else type(nodedef).__name__ - nodes.append(NODE(name=name, label=label, color=color, shape=shape)) + nodes.append(_NODE(name=name, label=label, kind=kind)) for input in nodedef.ts_inputs(): if input[1].nodedef not in visited: q.append(input[1].nodedef) - edges.append(EDGE(start=str(id(input[1].nodedef)), end=name)) + edges.append(_EDGE(start=str(id(input[1].nodedef)), end=name)) + return nodes, edges + + +def _build_graphviz_graph(graph_func, *args, **kwargs): + from graphviz import Digraph + + nodes, edges = _build_graph_for_viz(graph_func, *args, **kwargs) + + digraph = Digraph(strict=True) + digraph.attr(rankdir="LR", size="150,150") for node in nodes: digraph.node( node.name, node.label, style="filled", - fillcolor=node.color, - shape=node.shape, + fillcolor=_GRAPHVIZ_COLORMAP[node.kind], + shape=_GRAPHVIZ_SHAPEMAP[node.kind], ) for edge in edges: digraph.edge(edge.start, edge.end) @@ -60,25 +103,96 @@ def _build_graphviz_graph(graph_func, *args, **kwargs): return digraph +def _graphviz_to_buffer(digraph, image_format="png") -> BytesIO: + from graphviz import ExecutableNotFound + + digraph.format = image_format + buffer = BytesIO() + + try: + buffer.write(digraph.pipe()) + buffer.seek(0) + return buffer + except ExecutableNotFound as exc: + raise ModuleNotFoundError( + "Must install graphviz and have `dot` available on your PATH. See https://graphviz.org for installation instructions" + ) from exc + + def generate_graph(graph_func, *args, image_format="png", **kwargs): """Generate a BytesIO image representation of the given graph""" digraph = _build_graphviz_graph(graph_func, *args, **kwargs) - digraph.format = image_format - buffer = BytesIO() - buffer.write(digraph.pipe()) - buffer.seek(0) - return buffer + return _graphviz_to_buffer(digraph=digraph, image_format=image_format) -def show_graph(graph_func, *args, graph_filename=None, **kwargs): +def show_graph_pil(graph_func, *args, **kwargs): + buffer = generate_graph(graph_func, *args, image_format="png", **kwargs) + try: + from PIL import Image + except ImportError: + raise ModuleNotFoundError( + "csp requires `pillow` to display images. Install `pillow` with your python package manager, or pass `graph_filename` to generate a file output." + ) + image = Image.open(buffer) + image.show() + + +def show_graph_graphviz(graph_func, *args, graph_filename=None, **kwargs): + # extract the format of the image image_format = graph_filename.split(".")[-1] if graph_filename else "png" - buffer = generate_graph(graph_func, *args, image_format=image_format, **kwargs) + + # Generate graph with graphviz + digraph = _build_graphviz_graph(graph_func, *args, **kwargs) if graph_filename: + # output to file + buffer = _graphviz_to_buffer(digraph=digraph, image_format=image_format) with open(graph_filename, "wb") as f: f.write(buffer.read()) - else: - from PIL import Image + return digraph - image = Image.open(buffer) - image.show() + +def show_graph_widget(graph_func, *args, **kwargs): + try: + import ipydagred3 + except ImportError: + raise ModuleNotFoundError( + "csp requires `ipydagred3` to display graph widget. Install `ipydagred3` with your python package manager, or pass `graph_filename` to generate a file output." + ) + + nodes, edges = _build_graph_for_viz(graph_func=graph_func, *args, **kwargs) + + graph = ipydagred3.Graph(directed=True, attrs=dict(rankdir="LR")) + + for node in nodes: + graph.addNode( + ipydagred3.Node( + name=node.name, + label=node.label, + shape=_DAGRED3_SHAPEMAP[node.kind], + style=f"fill: {_DAGRED3_COLORMAP[node.kind]}", + ) + ) + for edge in edges: + graph.addEdge(edge.start, edge.end) + return ipydagred3.DagreD3Widget(graph=graph) + + +def show_graph(graph_func, *args, graph_filename=None, **kwargs): + # check if we're in jupyter + if _notebook_kind() == "notebook": + _HAVE_INTERACTIVE = True + else: + _HAVE_INTERACTIVE = False + + if graph_filename == "widget" and not _HAVE_INTERACTIVE: + # widget only works in Jupyter for now + raise RuntimeError("Interactive graph viewer only works in Jupyter.") + elif graph_filename == "widget": + # render with ipydagred3 + return show_graph_widget(graph_func, *args, **kwargs) + elif graph_filename in ("", None) and not _HAVE_INTERACTIVE: + # render with pillow + return show_graph_pil(graph_func, *args, **kwargs) + # render with graphviz + return show_graph_graphviz(graph_func, *args, graph_filename=graph_filename, **kwargs) diff --git a/examples/98_just_for_fun/e1_csp_nand_computer.py b/examples/98_just_for_fun/e1_csp_nand_computer.py index 378663f9f..f523b5e07 100644 --- a/examples/98_just_for_fun/e1_csp_nand_computer.py +++ b/examples/98_just_for_fun/e1_csp_nand_computer.py @@ -99,7 +99,9 @@ def my_graph(bits: int = 16): csp.print("y", basket_to_number(y)) csp.print("x_bits", basket_to_bitstring(x)) csp.print("y_bits", basket_to_bitstring(y)) + add = addInt(x, y) + csp.print("x+y", basket_to_number(add)) csp.print("x+y_bits", basket_to_bitstring(add)) @@ -107,6 +109,7 @@ def my_graph(bits: int = 16): def main(): # Show graph with 4-bit ints to limit size csp.showgraph.show_graph(my_graph, 4) + csp.run(my_graph, starttime=datetime(2022, 6, 24))