diff --git a/xircuits/compiler/generator.py b/xircuits/compiler/generator.py index 90715c82..f383b375 100644 --- a/xircuits/compiler/generator.py +++ b/xircuits/compiler/generator.py @@ -306,6 +306,9 @@ def _generate_main(self, flow_name): main = ast.parse(""" def main(args): import pprint + import asyncio + from asgiref.sync import sync_to_async + ctx = {} ctx['args'] = args flow = %s() @@ -325,9 +328,17 @@ def main(args): tpl = "flow.%s.value = args.%s" % (arg_name, arg_name) body.extend(ast.parse(tpl).body) - body.extend(ast.parse(""" -flow.do(ctx) -""").body) + execute = ast.parse(""" +@sync_to_async +def execute(): + try: + flow.do(ctx) + except: + import traceback + traceback.print_exc() + raise +""").body[0] + execute_body = execute.body # Print out the output values for i, port in enumerate(p for p in finish_node.ports if p.dataType == 'dynalist'): @@ -335,11 +346,26 @@ def main(args): if i > 0: port_name = "%s_%s" % (port_name, i) - body.extend(ast.parse(""" + execute_body.extend(ast.parse(""" print("%s:") pprint.pprint(flow.%s.value) """ % (port_name, port_name)).body) + body.append(execute) + body.extend(ast.parse(""" +event_loop = None +try: + event_loop = asyncio.get_running_loop() +except RuntimeError: + pass + +if event_loop: + event_loop.create_task(execute()) +else: + asyncio.run(execute()) + +""").body) + return [main] def _build_node_set(self):