@@ -17,8 +17,9 @@ const Allocator = std.mem.Allocator;
1717const Ident = base .Ident ;
1818const Region = base .Region ;
1919const ModuleWork = base .ModuleWork ;
20-
20+ const Func = types_mod . Func ;
2121const Var = types_mod .Var ;
22+ const Content = types_mod .Content ;
2223const exitOnOom = collections .utils .exitOnOom ;
2324
2425const Self = @This ();
@@ -226,102 +227,92 @@ pub fn checkExpr(self: *Self, expr_idx: CIR.Expr.Idx) std.mem.Allocator.Error!bo
226227
227228 if (all_exprs .len == 0 ) return false ; // No function to call
228229
229- // First expression is the function being called
230+ // First expression is the function being called; the rest are args.
230231 const func_expr_idx = all_exprs [0 ];
231- does_fx = try self .checkExpr (func_expr_idx ) or does_fx ; // There could be some effects done while creating this fn on the fly.
232-
233- // Check if the function being called is effectful by looking at its type
234- const func_type_var = @as (Var , @enumFromInt (@intFromEnum (func_expr_idx )));
235- const resolved_func = self .types .resolveVar (func_type_var );
236-
237- // Rest are arguments
232+ does_fx = try self .checkExpr (func_expr_idx ) or does_fx ; // func_expr could be effectful, e.g. `(mk_fn!())(arg)`
238233 const call_args = all_exprs [1.. ];
239234 for (call_args ) | arg_expr_idx | {
235+ // Each arg could also be effectful, e.g. `fn(mk_arg!(), mk_arg!())`
240236 does_fx = try self .checkExpr (arg_expr_idx ) or does_fx ;
241237 }
242238
243- // Handle function calls, being careful about runtime errors in the function position
239+ // Don't try to unify with the function if the function is a runtime error.
244240 if (self .can_ir .store .getExpr (func_expr_idx ) != .e_runtime_error ) {
245241 const call_var = @as (Var , @enumFromInt (@intFromEnum (expr_idx )));
246242 const func_var = @as (Var , @enumFromInt (@intFromEnum (func_expr_idx )));
243+ const resolved_func = self .types .resolveVar (func_var );
247244
248245 // Check if this is an annotated function that needs instantiation
249246 // We only instantiate if the function actually contains type variables
250- const needs_instantiation = switch (resolved_func .desc .content ) {
247+ var current_func_var = func_var ;
248+ var current_content = resolved_func .desc .content ;
249+
250+ content_switch : switch (current_content ) {
251251 .structure = > | flat_type | switch (flat_type ) {
252- .fn_effectful , .fn_pure , .fn_unbound = > self .types .needsInstantiation (func_var ),
253- else = > false ,
254- },
255- .alias = > self .types .needsInstantiation (func_var ),
256- else = > false ,
257- };
258-
259- if (needs_instantiation ) {
260- // Instantiate the function type to get fresh variables while preserving structure
261- const instantiated_func = try instantiate .instantiateVar (self .types , func_var , self .gpa );
262- const resolved_inst = self .types .resolveVar (instantiated_func );
263-
264- // Extract the instantiated function structure
265- const func_struct = switch (resolved_inst .desc .content ) {
266- .structure = > | flat_type | switch (flat_type ) {
267- .fn_effectful = > | func | blk : {
268- does_fx = true ;
269- break :blk func ;
270- },
271- .fn_pure , .fn_unbound = > | func | func ,
272- else = > unreachable ,
252+ .fn_effectful = > | _ | {
253+ does_fx = true ;
254+ if (self .types .needsInstantiation (current_func_var )) {
255+ const instantiated_var = try instantiate .instantiateVar (self .types , current_func_var , self .gpa );
256+ const resolved_inst = self .types .resolveVar (instantiated_var );
257+ std .debug .assert (resolved_inst .desc .content == .structure );
258+ std .debug .assert (resolved_inst .desc .content .structure == .fn_effectful );
259+ const inst_func = resolved_inst .desc .content .structure .fn_effectful ;
260+ try self .unifyCallWithFunc (call_var , inst_func , call_args , func_var );
261+ return does_fx ;
262+ }
273263 },
274- .alias = > | alias | blk : {
275- // Resolve through the alias to get the function
276- const backing_var = alias .getBackingVar (instantiated_func );
277- const alias_resolved = self .types .resolveVar (backing_var );
278- break :blk switch (alias_resolved .desc .content ) {
279- .structure = > | flat_type | switch (flat_type ) {
280- .fn_effectful = > | func | inner : {
281- does_fx = true ;
282- break :inner func ;
283- },
284- .fn_pure , .fn_unbound = > | func | func ,
285- else = > unreachable ,
286- },
287- else = > unreachable ,
288- };
264+ .fn_pure = > | _ | {
265+ if (self .types .needsInstantiation (current_func_var )) {
266+ const instantiated_var = try instantiate .instantiateVar (self .types , current_func_var , self .gpa );
267+ const resolved_inst = self .types .resolveVar (instantiated_var );
268+ std .debug .assert (resolved_inst .desc .content == .structure );
269+ std .debug .assert (resolved_inst .desc .content .structure == .fn_pure );
270+ const inst_func = resolved_inst .desc .content .structure .fn_pure ;
271+ try self .unifyCallWithFunc (call_var , inst_func , call_args , func_var );
272+ return does_fx ;
273+ }
289274 },
290- else = > unreachable ,
291- };
292-
293- // Unify instantiated argument types with actual arguments
294- const inst_args = self .types .getFuncArgsSlice (func_struct .args );
295- const arg_vars : []Var = @ptrCast (@alignCast (call_args ));
296-
297- // Only unify arguments if counts match - otherwise let the normal
298- // unification process handle the arity mismatch error
299- if (inst_args .len == arg_vars .len ) {
300- for (inst_args , arg_vars ) | inst_arg , actual_arg | {
301- _ = self .unify (inst_arg , actual_arg );
302- }
303- // The call's type is the instantiated return type
304- _ = self .unify (call_var , func_struct .ret );
305- } else {
306- // Fall back to normal unification to get proper error message
307- const func_content = self .types .mkFuncUnbound (arg_vars , call_var );
308- const expected_func_var = self .types .freshFromContent (func_content );
309- _ = self .unify (expected_func_var , func_var );
310- }
311- } else {
312- // Fall back to the old behavior for non-annotated functions
313- const arg_vars : []Var = @ptrCast (@alignCast (call_args ));
314-
315- // Create an unbound function type with the call result as return type
316- // The unification will propagate the actual return type to the call
317- //
318- // TODO: Do we need to insert a CIR placeholder node here as well?
319- // What happens if later this type variable has a problem, and we
320- // try to look up it's region in CIR?
321- const func_content = self .types .mkFuncUnbound (arg_vars , call_var );
322- const expected_func_var = self .types .freshFromContent (func_content );
323- _ = self .unify (expected_func_var , func_var );
275+ .fn_unbound = > | _ | {
276+ if (self .types .needsInstantiation (current_func_var )) {
277+ const instantiated_var = try instantiate .instantiateVar (self .types , current_func_var , self .gpa );
278+ const resolved_inst = self .types .resolveVar (instantiated_var );
279+ std .debug .assert (resolved_inst .desc .content == .structure );
280+ std .debug .assert (resolved_inst .desc .content .structure == .fn_unbound );
281+ const inst_func = resolved_inst .desc .content .structure .fn_unbound ;
282+ try self .unifyCallWithFunc (call_var , inst_func , call_args , func_var );
283+ return does_fx ;
284+ }
285+ },
286+ else = > {
287+ // Non-function structure - fall through
288+ },
289+ },
290+ .alias = > | alias | {
291+ // Resolve the alias, then continue on to the appropriate branch.
292+ // (It might be another alias, or we might be done and ready to proceed.)
293+ const backing_var = alias .getBackingVar (current_func_var );
294+ current_func_var = backing_var ;
295+ current_content = self .types .resolveVar (backing_var ).desc .content ;
296+ continue :content_switch current_content ;
297+ },
298+ else = > {
299+ // Non-structure content - fall through
300+ },
324301 }
302+
303+ // We didn't handle the function call above (either because it wasn't a function
304+ // or it didn't need instantiation), so fall back on this logic.
305+ const arg_vars : []Var = @constCast (@ptrCast (@alignCast (call_args )));
306+
307+ // Create an unbound function type with the call result as return type
308+ // The unification will propagate the actual return type to the call
309+ //
310+ // TODO: Do we need to insert a CIR placeholder node here as well?
311+ // What happens if later this type variable has a problem, and we
312+ // try to look up its region in CIR?
313+ const func_content = self .types .mkFuncUnbound (arg_vars , call_var );
314+ const expected_func_var = self .types .freshFromContent (func_content );
315+ _ = self .unify (expected_func_var , current_func_var );
325316 }
326317 },
327318 .e_record = > | e | {
@@ -515,6 +506,29 @@ pub fn checkExpr(self: *Self, expr_idx: CIR.Expr.Idx) std.mem.Allocator.Error!bo
515506 return does_fx ;
516507}
517508
509+ /// Helper function to unify a call expression with a function type
510+ fn unifyCallWithFunc (self : * Self , call_var : Var , func : types_mod.Func , call_args : []const CIR.Expr.Idx , original_func_var : Var ) std.mem.Allocator.Error ! void {
511+ // Unify instantiated argument types with actual arguments
512+ const inst_args = self .types .getFuncArgsSlice (func .args );
513+ const arg_vars : []Var = @constCast (@ptrCast (@alignCast (call_args )));
514+
515+ // Only unify arguments if counts match - otherwise let the normal
516+ // unification process handle the arity mismatch error
517+ if (inst_args .len == arg_vars .len ) {
518+ for (inst_args , arg_vars ) | inst_arg , actual_arg | {
519+ _ = self .unify (inst_arg , actual_arg );
520+ }
521+ // The call's type is the instantiated return type
522+ _ = self .unify (call_var , func .ret );
523+ } else {
524+ // Fall back to normal unification to get proper error message
525+ // Use the original func_var to avoid issues with instantiated variables in error reporting
526+ const func_content = self .types .mkFuncUnbound (arg_vars , call_var );
527+ const expected_func_var = self .types .freshFromContent (func_content );
528+ _ = self .unify (expected_func_var , original_func_var );
529+ }
530+ }
531+
518532/// Check a lambda expression with an optional expected type
519533fn checkLambdaWithExpected (self : * Self , expr_idx : CIR.Expr.Idx , lambda : anytype , expected_type : ? Var ) std.mem.Allocator.Error ! bool {
520534 const trace = tracy .trace (@src ());
0 commit comments