11from collections import deque , namedtuple
22from io import BytesIO
3+ from typing import Dict , Literal
34
45from csp .impl .wiring .runtime import build_graph
56
6- NODE = namedtuple ("NODE" , ["name" , "label" , "color" , "shape" ])
7- EDGE = namedtuple ("EDGE" , ["start" , "end" ])
7+ _KIND = Literal ["output" , "input" , "" ]
8+ _NODE = namedtuple ("NODE" , ["name" , "label" , "kind" ])
9+ _EDGE = namedtuple ("EDGE" , ["start" , "end" ])
810
11+ _GRAPHVIZ_COLORMAP : Dict [_KIND , str ] = {"output" : "red" , "input" : "cadetblue1" , "" : "white" }
912
10- def _build_graphviz_graph (graph_func , * args , ** kwargs ):
11- from graphviz import Digraph
13+ _GRAPHVIZ_SHAPEMAP : Dict [_KIND , str ] = {"output" : "rarrow" , "input" : "rarrow" , "" : "box" }
14+
15+ _DAGRED3_COLORMAP : Dict [_KIND , str ] = {
16+ "output" : "red" ,
17+ "input" : "#98f5ff" ,
18+ "" : "lightgrey" ,
19+ }
20+ _DAGRED3_SHAPEMAP : Dict [_KIND , str ] = {"output" : "diamond" , "input" : "diamond" , "" : "rect" }
21+
22+ _NOTEBOOK_KIND = Literal ["" , "terminal" , "notebook" ]
23+
24+ __all__ = (
25+ "generate_graph" ,
26+ "show_graph_pil" ,
27+ "show_graph_graphviz" ,
28+ "show_graph_widget" ,
29+ "show_graph" ,
30+ )
31+
32+
33+ def _notebook_kind () -> _NOTEBOOK_KIND :
34+ try :
35+ from IPython import get_ipython
1236
37+ shell = get_ipython ().__class__ .__name__
38+ if shell == "ZMQInteractiveShell" :
39+ return "notebook"
40+ elif shell == "TerminalInteractiveShell" :
41+ return "terminal"
42+ else :
43+ return ""
44+ except ImportError :
45+ return ""
46+ except NameError :
47+ return ""
48+
49+
50+ def _build_graph_for_viz (graph_func , * args , ** kwargs ):
1351 graph = build_graph (graph_func , * args , ** kwargs )
14- digraph = Digraph (strict = True )
15- digraph .attr (rankdir = "LR" , size = "150,150" )
1652
1753 rootnames = set ()
1854 q = deque ()
@@ -29,56 +65,142 @@ def _build_graphviz_graph(graph_func, *args, **kwargs):
2965 name = str (id (nodedef ))
3066 visited .add (nodedef )
3167 if name in rootnames : # output node
32- color = "red"
33- shape = "rarrow"
68+ kind = "output"
3469 elif not sum (1 for _ in nodedef .ts_inputs ()): # input node
35- color = "cadetblue1"
36- shape = "rarrow"
70+ kind = "input"
3771 else :
38- color = "white"
39- shape = "box"
72+ kind = ""
4073
4174 label = nodedef .__name__ if hasattr (nodedef , "__name__" ) else type (nodedef ).__name__
42- nodes .append (NODE (name = name , label = label , color = color , shape = shape ))
75+ nodes .append (_NODE (name = name , label = label , kind = kind ))
4376
4477 for input in nodedef .ts_inputs ():
4578 if input [1 ].nodedef not in visited :
4679 q .append (input [1 ].nodedef )
47- edges .append (EDGE (start = str (id (input [1 ].nodedef )), end = name ))
80+ edges .append (_EDGE (start = str (id (input [1 ].nodedef )), end = name ))
81+ return nodes , edges
82+
83+
84+ def _build_graphviz_graph (graph_func , * args , ** kwargs ):
85+ from graphviz import Digraph
86+
87+ nodes , edges = _build_graph_for_viz (graph_func = graph_func , * args , ** kwargs )
88+
89+ digraph = Digraph (strict = True )
90+ digraph .attr (rankdir = "LR" , size = "150,150" )
4891
4992 for node in nodes :
5093 digraph .node (
5194 node .name ,
5295 node .label ,
5396 style = "filled" ,
54- fillcolor = node .color ,
55- shape = node .shape ,
97+ fillcolor = _GRAPHVIZ_COLORMAP [ node .kind ] ,
98+ shape = _GRAPHVIZ_SHAPEMAP [ node .kind ] ,
5699 )
57100 for edge in edges :
58101 digraph .edge (edge .start , edge .end )
59102
60103 return digraph
61104
62105
106+ def _graphviz_to_buffer (digraph , image_format = "png" ) -> BytesIO :
107+ from graphviz import ExecutableNotFound
108+
109+ digraph .format = image_format
110+ buffer = BytesIO ()
111+
112+ try :
113+ buffer .write (digraph .pipe ())
114+ buffer .seek (0 )
115+ return buffer
116+ except ExecutableNotFound as exc :
117+ raise ModuleNotFoundError (
118+ "Must install graphviz and have `dot` available on your PATH. See https://graphviz.org for installation instructions"
119+ ) from exc
120+
121+
63122def generate_graph (graph_func , * args , image_format = "png" , ** kwargs ):
64123 """Generate a BytesIO image representation of the given graph"""
65124 digraph = _build_graphviz_graph (graph_func , * args , ** kwargs )
66- digraph .format = image_format
67- buffer = BytesIO ()
68- buffer .write (digraph .pipe ())
69- buffer .seek (0 )
70- return buffer
125+ return _graphviz_to_buffer (digraph = digraph , image_format = image_format )
71126
72127
73- def show_graph (graph_func , * args , graph_filename = None , ** kwargs ):
128+ def show_graph_pil (graph_func , * args , ** kwargs ):
129+ buffer = generate_graph (graph_func , * args , image_format = "png" , ** kwargs )
130+ try :
131+ from PIL import Image
132+ except ImportError :
133+ raise ModuleNotFoundError (
134+ "csp requires `pillow` to display images. Install `pillow` with your python package manager, or pass `graph_filename` to generate a file output."
135+ )
136+ image = Image .open (buffer )
137+ image .show ()
138+
139+
140+ def show_graph_graphviz (graph_func , * args , graph_filename = None , interactive = False , ** kwargs ):
141+ # extract the format of the image
74142 image_format = graph_filename .split ("." )[- 1 ] if graph_filename else "png"
75- buffer = generate_graph (graph_func , * args , image_format = image_format , ** kwargs )
76143
77- if graph_filename :
78- with open (graph_filename , "wb" ) as f :
79- f .write (buffer .read ())
144+ # Generate graph with graphviz
145+ digraph = _build_graphviz_graph (graph_func , * args , ** kwargs )
146+
147+ # if we're in a notebook, return it directly for rendering
148+ if interactive :
149+ return digraph
150+
151+ # otherwise output to file
152+ buffer = _graphviz_to_buffer (digraph = digraph , image_format = image_format )
153+ with open (graph_filename , "wb" ) as f :
154+ f .write (buffer .read ())
155+ return digraph
156+
157+
158+ def show_graph_widget (graph_func , * args , ** kwargs ):
159+ try :
160+ import ipydagred3
161+ except ImportError :
162+ raise ModuleNotFoundError (
163+ "csp requires `ipydagred3` to display graph widget. Install `ipydagred3` with your python package manager, or pass `graph_filename` to generate a file output."
164+ )
165+
166+ nodes , edges = _build_graph_for_viz (graph_func = graph_func , * args , ** kwargs )
167+
168+ graph = ipydagred3 .Graph (directed = True , attrs = dict (rankdir = "LR" ))
169+
170+ for node in nodes :
171+ graph .addNode (
172+ ipydagred3 .Node (
173+ name = node .name ,
174+ label = node .label ,
175+ shape = _DAGRED3_SHAPEMAP [node .kind ],
176+ style = f"fill: { _DAGRED3_COLORMAP [node .kind ]} " ,
177+ )
178+ )
179+ for edge in edges :
180+ graph .addEdge (edge .start , edge .end )
181+ return ipydagred3 .DagreD3Widget (graph = graph )
182+
183+
184+ def show_graph (graph_func , * args , graph_filename = None , ** kwargs ):
185+ # check if we're in jupyter
186+ if _notebook_kind () == "notebook" :
187+ _HAVE_INTERACTIVE = True
80188 else :
81- from PIL import Image
189+ _HAVE_INTERACTIVE = False
190+
191+ # display graph via pillow or ipydagred3
192+ if graph_filename in (None , "widget" ):
193+ if graph_filename == "widget" and not _HAVE_INTERACTIVE :
194+ raise RuntimeError ("Interactive graph viewer only works in Jupyter." )
195+
196+ # render with ipydagred3
197+ if graph_filename == "widget" :
198+ return show_graph_widget (graph_func , * args , ** kwargs )
199+
200+ # render with pillow
201+ return show_graph_pil (graph_func , * args , ** kwargs )
82202
83- image = Image .open (buffer )
84- image .show ()
203+ # TODO we can show graphviz in jupyter without a filename, but preserving existing behavior for now
204+ return show_graph_graphviz (
205+ graph_func , * args , graph_filename = graph_filename , interactive = _HAVE_INTERACTIVE , ** kwargs
206+ )
0 commit comments