diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 8ad8cb97..b48e420a 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -289,42 +289,47 @@ def visit_Call(self, node: ast.Call) -> ast.AST: block_info.from_config(self.device_function.config) ) ) - elif isinstance(type_info, SequenceType): + elif isinstance(type_info, SequenceType) and all( + isinstance(x, TileIndexType) for x in type_info.unpack() + ): values = type_info.unpack() - if all(isinstance(x, TileIndexType) for x in values): - block_infos = [env.block_sizes[x.block_id] for x in values] # pyright: ignore[reportAttributeAccessIssue] - return expr_from_string( - self.host_function.literal_expr( - [ - x.from_config(self.device_function.config) - for x in block_infos - ] - ) + block_infos = [env.block_sizes[x.block_id] for x in values] # pyright: ignore[reportAttributeAccessIssue] + return expr_from_string( + self.host_function.literal_expr( + [x.from_config(self.device_function.config) for x in block_infos] ) + ) elif ( isinstance(fn_type_info := func_node._type_info, CallableType) and is_api_func(api := fn_type_info.value) and api._codegen is not None ): + ast_args = [] + ast_kwargs = {} proxy_args = [] proxy_kwargs = {} for arg in node.args: assert not isinstance(arg, ast.Starred) assert isinstance(arg, ExtendedAST) assert arg._type_info is not None + ast_args.append(arg) proxy_args.append(arg._type_info.proxy()) for kwarg in node.keywords: assert kwarg.arg is not None assert isinstance(kwarg.value, ExtendedAST) assert kwarg.value._type_info is not None + ast_kwargs[kwarg.arg] = kwarg.value proxy_kwargs[kwarg.arg] = kwarg.value._type_info.proxy() + ast_params = api._signature.bind(*ast_args, **ast_kwargs) proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs) + ast_params.apply_defaults() proxy_params.apply_defaults() return api._codegen( # pyright: ignore[reportReturnType] CodegenState( self, None, proxy_args=[*proxy_params.arguments.values()], + ast_args=[*ast_params.arguments.values()], ) ) return self.generic_visit(node)