Skip to content

Commit 01c831e

Browse files
authored
[generate_ast] providing AST args, and fall back to api._codegen when output is a tuple (#481)
1 parent d4646fb commit 01c831e

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

helion/_compiler/generate_ast.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,42 +289,47 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
289289
block_info.from_config(self.device_function.config)
290290
)
291291
)
292-
elif isinstance(type_info, SequenceType):
292+
elif isinstance(type_info, SequenceType) and all(
293+
isinstance(x, TileIndexType) for x in type_info.unpack()
294+
):
293295
values = type_info.unpack()
294-
if all(isinstance(x, TileIndexType) for x in values):
295-
block_infos = [env.block_sizes[x.block_id] for x in values] # pyright: ignore[reportAttributeAccessIssue]
296-
return expr_from_string(
297-
self.host_function.literal_expr(
298-
[
299-
x.from_config(self.device_function.config)
300-
for x in block_infos
301-
]
302-
)
296+
block_infos = [env.block_sizes[x.block_id] for x in values] # pyright: ignore[reportAttributeAccessIssue]
297+
return expr_from_string(
298+
self.host_function.literal_expr(
299+
[x.from_config(self.device_function.config) for x in block_infos]
303300
)
301+
)
304302
elif (
305303
isinstance(fn_type_info := func_node._type_info, CallableType)
306304
and is_api_func(api := fn_type_info.value)
307305
and api._codegen is not None
308306
):
307+
ast_args = []
308+
ast_kwargs = {}
309309
proxy_args = []
310310
proxy_kwargs = {}
311311
for arg in node.args:
312312
assert not isinstance(arg, ast.Starred)
313313
assert isinstance(arg, ExtendedAST)
314314
assert arg._type_info is not None
315+
ast_args.append(arg)
315316
proxy_args.append(arg._type_info.proxy())
316317
for kwarg in node.keywords:
317318
assert kwarg.arg is not None
318319
assert isinstance(kwarg.value, ExtendedAST)
319320
assert kwarg.value._type_info is not None
321+
ast_kwargs[kwarg.arg] = kwarg.value
320322
proxy_kwargs[kwarg.arg] = kwarg.value._type_info.proxy()
323+
ast_params = api._signature.bind(*ast_args, **ast_kwargs)
321324
proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs)
325+
ast_params.apply_defaults()
322326
proxy_params.apply_defaults()
323327
return api._codegen( # pyright: ignore[reportReturnType]
324328
CodegenState(
325329
self,
326330
None,
327331
proxy_args=[*proxy_params.arguments.values()],
332+
ast_args=[*ast_params.arguments.values()],
328333
)
329334
)
330335
return self.generic_visit(node)

0 commit comments

Comments
 (0)