@@ -289,42 +289,47 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
289
289
block_info .from_config (self .device_function .config )
290
290
)
291
291
)
292
- elif isinstance (type_info , SequenceType ):
292
+ elif isinstance (type_info , SequenceType ) and all (
293
+ isinstance (x , TileIndexType ) for x in type_info .unpack ()
294
+ ):
293
295
values = type_info .unpack ()
294
- if all (isinstance (x , TileIndexType ) for x in values ):
295
- block_infos = [env .block_sizes [x .block_id ] for x in values ] # pyright: ignore[reportAttributeAccessIssue]
296
- return expr_from_string (
297
- self .host_function .literal_expr (
298
- [
299
- x .from_config (self .device_function .config )
300
- for x in block_infos
301
- ]
302
- )
296
+ block_infos = [env .block_sizes [x .block_id ] for x in values ] # pyright: ignore[reportAttributeAccessIssue]
297
+ return expr_from_string (
298
+ self .host_function .literal_expr (
299
+ [x .from_config (self .device_function .config ) for x in block_infos ]
303
300
)
301
+ )
304
302
elif (
305
303
isinstance (fn_type_info := func_node ._type_info , CallableType )
306
304
and is_api_func (api := fn_type_info .value )
307
305
and api ._codegen is not None
308
306
):
307
+ ast_args = []
308
+ ast_kwargs = {}
309
309
proxy_args = []
310
310
proxy_kwargs = {}
311
311
for arg in node .args :
312
312
assert not isinstance (arg , ast .Starred )
313
313
assert isinstance (arg , ExtendedAST )
314
314
assert arg ._type_info is not None
315
+ ast_args .append (arg )
315
316
proxy_args .append (arg ._type_info .proxy ())
316
317
for kwarg in node .keywords :
317
318
assert kwarg .arg is not None
318
319
assert isinstance (kwarg .value , ExtendedAST )
319
320
assert kwarg .value ._type_info is not None
321
+ ast_kwargs [kwarg .arg ] = kwarg .value
320
322
proxy_kwargs [kwarg .arg ] = kwarg .value ._type_info .proxy ()
323
+ ast_params = api ._signature .bind (* ast_args , ** ast_kwargs )
321
324
proxy_params = api ._signature .bind (* proxy_args , ** proxy_kwargs )
325
+ ast_params .apply_defaults ()
322
326
proxy_params .apply_defaults ()
323
327
return api ._codegen ( # pyright: ignore[reportReturnType]
324
328
CodegenState (
325
329
self ,
326
330
None ,
327
331
proxy_args = [* proxy_params .arguments .values ()],
332
+ ast_args = [* ast_params .arguments .values ()],
328
333
)
329
334
)
330
335
return self .generic_visit (node )
0 commit comments