Skip to content

Commit 60ae226

Browse files
authored
Merge pull request #7952 from roc-lang/refactor-instantiation
Refactor instantiation
2 parents 149b464 + 2d826c9 commit 60ae226

File tree

2 files changed

+98
-84
lines changed

2 files changed

+98
-84
lines changed

src/check/check_types.zig

Lines changed: 95 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ const Allocator = std.mem.Allocator;
1717
const Ident = base.Ident;
1818
const Region = base.Region;
1919
const ModuleWork = base.ModuleWork;
20-
20+
const Func = types_mod.Func;
2121
const Var = types_mod.Var;
22+
const Content = types_mod.Content;
2223
const exitOnOom = collections.utils.exitOnOom;
2324

2425
const Self = @This();
@@ -226,102 +227,92 @@ pub fn checkExpr(self: *Self, expr_idx: CIR.Expr.Idx) std.mem.Allocator.Error!bo
226227

227228
if (all_exprs.len == 0) return false; // No function to call
228229

229-
// First expression is the function being called
230+
// First expression is the function being called; the rest are args.
230231
const func_expr_idx = all_exprs[0];
231-
does_fx = try self.checkExpr(func_expr_idx) or does_fx; // There could be some effects done while creating this fn on the fly.
232-
233-
// Check if the function being called is effectful by looking at its type
234-
const func_type_var = @as(Var, @enumFromInt(@intFromEnum(func_expr_idx)));
235-
const resolved_func = self.types.resolveVar(func_type_var);
236-
237-
// Rest are arguments
232+
does_fx = try self.checkExpr(func_expr_idx) or does_fx; // func_expr could be effectful, e.g. `(mk_fn!())(arg)`
238233
const call_args = all_exprs[1..];
239234
for (call_args) |arg_expr_idx| {
235+
// Each arg could also be effectful, e.g. `fn(mk_arg!(), mk_arg!())`
240236
does_fx = try self.checkExpr(arg_expr_idx) or does_fx;
241237
}
242238

243-
// Handle function calls, being careful about runtime errors in the function position
239+
// Don't try to unify with the function if the function is a runtime error.
244240
if (self.can_ir.store.getExpr(func_expr_idx) != .e_runtime_error) {
245241
const call_var = @as(Var, @enumFromInt(@intFromEnum(expr_idx)));
246242
const func_var = @as(Var, @enumFromInt(@intFromEnum(func_expr_idx)));
243+
const resolved_func = self.types.resolveVar(func_var);
247244

248245
// Check if this is an annotated function that needs instantiation
249246
// We only instantiate if the function actually contains type variables
250-
const needs_instantiation = switch (resolved_func.desc.content) {
247+
var current_func_var = func_var;
248+
var current_content = resolved_func.desc.content;
249+
250+
content_switch: switch (current_content) {
251251
.structure => |flat_type| switch (flat_type) {
252-
.fn_effectful, .fn_pure, .fn_unbound => self.types.needsInstantiation(func_var),
253-
else => false,
254-
},
255-
.alias => self.types.needsInstantiation(func_var),
256-
else => false,
257-
};
258-
259-
if (needs_instantiation) {
260-
// Instantiate the function type to get fresh variables while preserving structure
261-
const instantiated_func = try instantiate.instantiateVar(self.types, func_var, self.gpa);
262-
const resolved_inst = self.types.resolveVar(instantiated_func);
263-
264-
// Extract the instantiated function structure
265-
const func_struct = switch (resolved_inst.desc.content) {
266-
.structure => |flat_type| switch (flat_type) {
267-
.fn_effectful => |func| blk: {
268-
does_fx = true;
269-
break :blk func;
270-
},
271-
.fn_pure, .fn_unbound => |func| func,
272-
else => unreachable,
252+
.fn_effectful => |_| {
253+
does_fx = true;
254+
if (self.types.needsInstantiation(current_func_var)) {
255+
const instantiated_var = try instantiate.instantiateVar(self.types, current_func_var, self.gpa);
256+
const resolved_inst = self.types.resolveVar(instantiated_var);
257+
std.debug.assert(resolved_inst.desc.content == .structure);
258+
std.debug.assert(resolved_inst.desc.content.structure == .fn_effectful);
259+
const inst_func = resolved_inst.desc.content.structure.fn_effectful;
260+
try self.unifyCallWithFunc(call_var, inst_func, call_args, func_var);
261+
return does_fx;
262+
}
273263
},
274-
.alias => |alias| blk: {
275-
// Resolve through the alias to get the function
276-
const backing_var = alias.getBackingVar(instantiated_func);
277-
const alias_resolved = self.types.resolveVar(backing_var);
278-
break :blk switch (alias_resolved.desc.content) {
279-
.structure => |flat_type| switch (flat_type) {
280-
.fn_effectful => |func| inner: {
281-
does_fx = true;
282-
break :inner func;
283-
},
284-
.fn_pure, .fn_unbound => |func| func,
285-
else => unreachable,
286-
},
287-
else => unreachable,
288-
};
264+
.fn_pure => |_| {
265+
if (self.types.needsInstantiation(current_func_var)) {
266+
const instantiated_var = try instantiate.instantiateVar(self.types, current_func_var, self.gpa);
267+
const resolved_inst = self.types.resolveVar(instantiated_var);
268+
std.debug.assert(resolved_inst.desc.content == .structure);
269+
std.debug.assert(resolved_inst.desc.content.structure == .fn_pure);
270+
const inst_func = resolved_inst.desc.content.structure.fn_pure;
271+
try self.unifyCallWithFunc(call_var, inst_func, call_args, func_var);
272+
return does_fx;
273+
}
289274
},
290-
else => unreachable,
291-
};
292-
293-
// Unify instantiated argument types with actual arguments
294-
const inst_args = self.types.getFuncArgsSlice(func_struct.args);
295-
const arg_vars: []Var = @ptrCast(@alignCast(call_args));
296-
297-
// Only unify arguments if counts match - otherwise let the normal
298-
// unification process handle the arity mismatch error
299-
if (inst_args.len == arg_vars.len) {
300-
for (inst_args, arg_vars) |inst_arg, actual_arg| {
301-
_ = self.unify(inst_arg, actual_arg);
302-
}
303-
// The call's type is the instantiated return type
304-
_ = self.unify(call_var, func_struct.ret);
305-
} else {
306-
// Fall back to normal unification to get proper error message
307-
const func_content = self.types.mkFuncUnbound(arg_vars, call_var);
308-
const expected_func_var = self.types.freshFromContent(func_content);
309-
_ = self.unify(expected_func_var, func_var);
310-
}
311-
} else {
312-
// Fall back to the old behavior for non-annotated functions
313-
const arg_vars: []Var = @ptrCast(@alignCast(call_args));
314-
315-
// Create an unbound function type with the call result as return type
316-
// The unification will propagate the actual return type to the call
317-
//
318-
// TODO: Do we need to insert a CIR placeholder node here as well?
319-
// What happens if later this type variable has a problem, and we
320-
// try to look up it's region in CIR?
321-
const func_content = self.types.mkFuncUnbound(arg_vars, call_var);
322-
const expected_func_var = self.types.freshFromContent(func_content);
323-
_ = self.unify(expected_func_var, func_var);
275+
.fn_unbound => |_| {
276+
if (self.types.needsInstantiation(current_func_var)) {
277+
const instantiated_var = try instantiate.instantiateVar(self.types, current_func_var, self.gpa);
278+
const resolved_inst = self.types.resolveVar(instantiated_var);
279+
std.debug.assert(resolved_inst.desc.content == .structure);
280+
std.debug.assert(resolved_inst.desc.content.structure == .fn_unbound);
281+
const inst_func = resolved_inst.desc.content.structure.fn_unbound;
282+
try self.unifyCallWithFunc(call_var, inst_func, call_args, func_var);
283+
return does_fx;
284+
}
285+
},
286+
else => {
287+
// Non-function structure - fall through
288+
},
289+
},
290+
.alias => |alias| {
291+
// Resolve the alias, then continue on to the appropriate branch.
292+
// (It might be another alias, or we might be done and ready to proceed.)
293+
const backing_var = alias.getBackingVar(current_func_var);
294+
current_func_var = backing_var;
295+
current_content = self.types.resolveVar(backing_var).desc.content;
296+
continue :content_switch current_content;
297+
},
298+
else => {
299+
// Non-structure content - fall through
300+
},
324301
}
302+
303+
// We didn't handle the function call above (either because it wasn't a function
304+
// or it didn't need instantiation), so fall back on this logic.
305+
const arg_vars: []Var = @constCast(@ptrCast(@alignCast(call_args)));
306+
307+
// Create an unbound function type with the call result as return type
308+
// The unification will propagate the actual return type to the call
309+
//
310+
// TODO: Do we need to insert a CIR placeholder node here as well?
311+
// What happens if later this type variable has a problem, and we
312+
// try to look up its region in CIR?
313+
const func_content = self.types.mkFuncUnbound(arg_vars, call_var);
314+
const expected_func_var = self.types.freshFromContent(func_content);
315+
_ = self.unify(expected_func_var, current_func_var);
325316
}
326317
},
327318
.e_record => |e| {
@@ -515,6 +506,29 @@ pub fn checkExpr(self: *Self, expr_idx: CIR.Expr.Idx) std.mem.Allocator.Error!bo
515506
return does_fx;
516507
}
517508

509+
/// Helper function to unify a call expression with a function type
510+
fn unifyCallWithFunc(self: *Self, call_var: Var, func: types_mod.Func, call_args: []const CIR.Expr.Idx, original_func_var: Var) std.mem.Allocator.Error!void {
511+
// Unify instantiated argument types with actual arguments
512+
const inst_args = self.types.getFuncArgsSlice(func.args);
513+
const arg_vars: []Var = @constCast(@ptrCast(@alignCast(call_args)));
514+
515+
// Only unify arguments if counts match - otherwise let the normal
516+
// unification process handle the arity mismatch error
517+
if (inst_args.len == arg_vars.len) {
518+
for (inst_args, arg_vars) |inst_arg, actual_arg| {
519+
_ = self.unify(inst_arg, actual_arg);
520+
}
521+
// The call's type is the instantiated return type
522+
_ = self.unify(call_var, func.ret);
523+
} else {
524+
// Fall back to normal unification to get proper error message
525+
// Use the original func_var to avoid issues with instantiated variables in error reporting
526+
const func_content = self.types.mkFuncUnbound(arg_vars, call_var);
527+
const expected_func_var = self.types.freshFromContent(func_content);
528+
_ = self.unify(expected_func_var, original_func_var);
529+
}
530+
}
531+
518532
/// Check a lambda expression with an optional expected type
519533
fn checkLambdaWithExpected(self: *Self, expr_idx: CIR.Expr.Idx, lambda: anytype, expected_type: ?Var) std.mem.Allocator.Error!bool {
520534
const trace = tracy.trace(@src());

src/check/let_polymorphism_integration_test.zig

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ fn typeCheckExpr(allocator: std.mem.Allocator, source: []const u8) !struct {
5252
var canon_expr_idx: ?CIR.Expr.Idx = null;
5353
if (parse_ast.root_node_idx != 0) {
5454
const expr_idx: parse.AST.Expr.Idx = @enumFromInt(parse_ast.root_node_idx);
55-
canon_expr_idx = try can.canonicalize_expr(expr_idx);
55+
canon_expr_idx = try can.canonicalizeExpr(expr_idx);
5656
}
5757

5858
// Type check - continue even if there are parse errors
@@ -126,7 +126,7 @@ fn typeCheckFile(allocator: std.mem.Allocator, source: []const u8) !struct {
126126
};
127127
}
128128

129-
try can.canonicalize_file();
129+
try can.canonicalizeFile();
130130

131131
// Type check - continue even if there are parse errors
132132
const checker = try allocator.create(check_types);
@@ -187,7 +187,7 @@ fn typeCheckStatement(allocator: std.mem.Allocator, source: []const u8) !struct
187187
var canon_result: ?CIR.Expr.Idx = null;
188188
if (parse_ast.root_node_idx != 0) {
189189
const stmt_idx: parse.AST.Statement.Idx = @enumFromInt(parse_ast.root_node_idx);
190-
canon_result = try can.canonicalize_statement(stmt_idx);
190+
canon_result = try can.canonicalizeStatement(stmt_idx);
191191
}
192192

193193
// Type check - continue even if there are parse errors

0 commit comments

Comments
 (0)