Skip to content

Commit 3ab7c65

Browse files
committed
Type-check record field access
1 parent 683d95b commit 3ab7c65

File tree

177 files changed

+733
-1151
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

177 files changed

+733
-1151
lines changed

src/check/Check.zig

Lines changed: 176 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3249,12 +3249,35 @@ pub fn checkCIR2Expr(self: *Self, comptime CIR2: type, cir2: *const CIR2, expr_i
32493249

32503250
// Check all branches - they must all have the same type
32513251
var branch_type: ?Var = null;
3252-
while (iter.next()) |pattern_idx| {
3252+
while (iter.next()) |pattern_or_guard_idx| {
32533253
const body_idx = iter.next() orelse break;
32543254

3255-
// Check the pattern against the scrutinee type
3256-
const pattern_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(pattern_idx)));
3257-
_ = try self.unify(pattern_var, scrutinee_var);
3255+
// Check if this is a guarded pattern (if_without_else node)
3256+
const pattern_node = cir2.getNode(@enumFromInt(@intFromEnum(pattern_or_guard_idx)));
3257+
3258+
if (pattern_node.tag == .if_without_else) {
3259+
// This is a guarded pattern: pattern if guard_condition
3260+
// The if_without_else node contains [pattern, guard_condition]
3261+
const branches_idx = @as(collections.NodeSlices(AST2.Node.Idx).Idx, @enumFromInt(pattern_node.payload.if_branches));
3262+
var guard_iter = cir2.ast.node_slices.nodes(&branches_idx);
3263+
3264+
// Extract pattern and guard
3265+
const pattern_idx = guard_iter.next() orelse break;
3266+
const guard_idx = guard_iter.next() orelse break;
3267+
3268+
// Check the pattern against the scrutinee type
3269+
const pattern_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(pattern_idx)));
3270+
_ = try self.unify(pattern_var, scrutinee_var);
3271+
3272+
// Check the guard condition - must be Bool
3273+
const guard_var = try self.checkCIR2Expr(CIR2, cir2, @enumFromInt(@intFromEnum(guard_idx)));
3274+
// For now, just ensure guard type exists - proper Bool type checking coming later
3275+
_ = guard_var;
3276+
} else {
3277+
// Regular pattern without guard
3278+
const pattern_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(pattern_or_guard_idx)));
3279+
_ = try self.unify(pattern_var, scrutinee_var);
3280+
}
32583281

32593282
// Check the body
32603283
const body_var = try self.checkCIR2Expr(CIR2, cir2, @enumFromInt(@intFromEnum(body_idx)));
@@ -3341,6 +3364,24 @@ fn checkCIR2Pattern(
33413364
cir2: *const CIR2,
33423365
patt_idx: CIR2.Patt.Idx,
33433366
) !Var {
3367+
const patt_node = cir2.getNode(patt_idx.toNodeIdx());
3368+
3369+
// Handle pattern alternatives (binop_or for pattern1 | pattern2)
3370+
if (patt_node.tag == .binop_or) {
3371+
const binop = cir2.ast.node_slices.binOp(patt_node.payload.binop);
3372+
3373+
// Check both alternatives against the same type
3374+
const lhs_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(binop.lhs)));
3375+
const rhs_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(binop.rhs)));
3376+
3377+
// Both alternatives must have the same type
3378+
_ = try self.unify(lhs_var, rhs_var);
3379+
3380+
const patt_var = @as(Var, @enumFromInt(@intFromEnum(patt_idx.toNodeIdx())));
3381+
_ = try self.unify(patt_var, lhs_var);
3382+
return patt_var;
3383+
}
3384+
33443385
const patt = cir2.getPatt(patt_idx);
33453386

33463387
// Get the type variable for this pattern (should already exist from canonicalization)
@@ -3389,19 +3430,54 @@ fn checkCIR2Pattern(
33893430
// Tag pattern with optional payload
33903431
var iter = cir2.ast.node_slices.nodes(&patt.payload.nodes);
33913432

3392-
// Skip tag constructor
3393-
_ = iter.next();
3433+
// First element is the tag constructor
3434+
const tag_idx = iter.next();
3435+
3436+
// Collect payload pattern types
3437+
var payload_types = std.ArrayList(Var).init(self.gpa);
3438+
defer payload_types.deinit();
33943439

3395-
// Check payload pattern if present
3396-
if (iter.next()) |payload_idx| {
3397-
// Payload should be a pattern
3398-
const payload_patt = cir2.getPatt(@enumFromInt(@intFromEnum(payload_idx)));
3399-
if (payload_patt.tag != .malformed) {
3400-
_ = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(payload_idx)));
3440+
while (iter.next()) |payload_idx| {
3441+
// Check nested patterns in payload
3442+
const payload_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(payload_idx)));
3443+
try payload_types.append(payload_var);
3444+
}
3445+
3446+
// Create tag union type with this tag
3447+
if (tag_idx) |tag_node_idx| {
3448+
const tag_node = cir2.getNode(@enumFromInt(@intFromEnum(tag_node_idx)));
3449+
if (tag_node.tag == .uc) {
3450+
// Get tag name
3451+
const tag_name = tag_node.payload.ident;
3452+
3453+
// Create a tag with payload types
3454+
const payload_range = if (payload_types.items.len > 0)
3455+
try self.types.appendVars(payload_types.items)
3456+
else
3457+
types_mod.Var.SafeList.Range.empty();
3458+
3459+
const tag = types_mod.Tag{ .name = tag_name, .args = payload_range };
3460+
const tag_idx_new = try self.types.tags.append(self.gpa, tag);
3461+
3462+
// Create tag union with this single tag
3463+
const tag_range = types_mod.Tag.SafeMultiList.Range{
3464+
.start = tag_idx_new,
3465+
.count = 1,
3466+
};
3467+
3468+
// Extension is unbound for now
3469+
const ext_var = try self.types.fresh();
3470+
3471+
const tag_union_content = Content{ .structure = .{ .tag_union = .{
3472+
.tags = tag_range,
3473+
.ext = ext_var,
3474+
} } };
3475+
3476+
const tag_union_var = try self.types.freshFromContent(tag_union_content);
3477+
_ = try self.unify(patt_var, tag_union_var);
34013478
}
34023479
}
34033480

3404-
// Pattern has tag union type
34053481
return patt_var;
34063482
},
34073483

@@ -3413,28 +3489,43 @@ fn checkCIR2Pattern(
34133489

34143490
var iter = cir2.ast.node_slices.nodes(&patt.payload.nodes);
34153491
while (iter.next()) |elem_idx| {
3416-
// Elements should be patterns
3417-
const elem_patt = cir2.getPatt(@enumFromInt(@intFromEnum(elem_idx)));
3418-
if (elem_patt.tag != .malformed) {
3419-
const elem_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(elem_idx)));
3420-
try elem_vars.append(elem_var);
3421-
}
3492+
// Recursively check nested patterns
3493+
const elem_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(elem_idx)));
3494+
try elem_vars.append(elem_var);
34223495
}
34233496

34243497
// Create tuple type from element types
3425-
// For now, just use a fresh variable
3498+
if (elem_vars.items.len > 0) {
3499+
const elems_range = try self.types.appendVars(elem_vars.items);
3500+
const tuple_content = Content{ .structure = .{ .tuple = .{ .elems = elems_range } } };
3501+
const tuple_var = try self.types.freshFromContent(tuple_content);
3502+
_ = try self.unify(patt_var, tuple_var);
3503+
}
3504+
34263505
return patt_var;
34273506
},
34283507

34293508
.list_destructure => {
3430-
// List pattern - all elements must have same type
3509+
// List pattern - all non-rest elements must have same type
34313510
var elem_type: ?Var = null;
3511+
var has_rest = false;
34323512

34333513
var iter = cir2.ast.node_slices.nodes(&patt.payload.nodes);
34343514
while (iter.next()) |elem_idx| {
3435-
// Elements should be patterns
34363515
const elem_patt = cir2.getPatt(@enumFromInt(@intFromEnum(elem_idx)));
3437-
if (elem_patt.tag != .malformed) {
3516+
3517+
if (elem_patt.tag == .double_dot_ident) {
3518+
// Rest pattern in list: [first, ..rest]
3519+
has_rest = true;
3520+
// Rest pattern binds to a list of the element type
3521+
if (elem_type) |et| {
3522+
const rest_list_content = Content{ .structure = .{ .list = et } };
3523+
const rest_list_var = try self.types.freshFromContent(rest_list_content);
3524+
const rest_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(elem_idx)));
3525+
_ = try self.unify(rest_var, rest_list_var);
3526+
}
3527+
} else if (elem_patt.tag != .malformed) {
3528+
// Regular element pattern
34383529
const elem_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(elem_idx)));
34393530

34403531
if (elem_type) |expected| {
@@ -3450,17 +3541,28 @@ fn checkCIR2Pattern(
34503541
const list_content = Content{ .structure = .{ .list = et } };
34513542
const list_var = try self.types.freshFromContent(list_content);
34523543
_ = try self.unify(patt_var, list_var);
3544+
} else if (!has_rest) {
3545+
// Empty list pattern []
3546+
const unbound_elem = try self.types.fresh();
3547+
const list_content = Content{ .structure = .{ .list = unbound_elem } };
3548+
const list_var = try self.types.freshFromContent(list_content);
3549+
_ = try self.unify(patt_var, list_var);
34533550
}
34543551

34553552
return patt_var;
34563553
},
34573554

34583555
.record_destructure => {
3459-
// Record pattern - check field patterns
3460-
// Fields can be binop_colon (field: pattern) or direct patterns
3556+
// Record pattern - check field patterns and build record type
3557+
var fields = std.ArrayList(types_mod.RecordField).init(self.gpa);
3558+
defer fields.deinit();
3559+
3560+
var has_rest = false;
3561+
var rest_var: ?Var = null;
3562+
34613563
var iter = cir2.ast.node_slices.nodes(&patt.payload.nodes);
34623564
while (iter.next()) |field_idx| {
3463-
const field_node = cir2.getNode(field_idx);
3565+
const field_node = cir2.getNode(@enumFromInt(@intFromEnum(field_idx)));
34643566
const tag_value = @as(u8, @intFromEnum(field_node.tag));
34653567

34663568
// Check if it's an expression (binop_colon for field: pattern)
@@ -3469,16 +3571,59 @@ fn checkCIR2Pattern(
34693571
const expr_view = cir2.getExpr(@enumFromInt(@intFromEnum(field_idx)));
34703572
if (expr_view.tag == .binop_colon) {
34713573
const binop = cir2.getBinOp(CIR2.Patt.Idx, expr_view.payload.binop);
3472-
// RHS is the pattern
3473-
_ = try self.checkCIR2Pattern(CIR2, cir2, binop.rhs);
3574+
3575+
// LHS is field name, RHS is the pattern
3576+
const name_node = cir2.getNode(@enumFromInt(@intFromEnum(binop.lhs)));
3577+
if (name_node.tag == .lc) {
3578+
const field_name = name_node.payload.ident;
3579+
const field_type = try self.checkCIR2Pattern(CIR2, cir2, binop.rhs);
3580+
3581+
try fields.append(.{
3582+
.name = field_name,
3583+
.var_ = field_type,
3584+
});
3585+
}
34743586
}
34753587
} else if (tag_value >= CIR2.PATT_TAG_START) {
3476-
// It's a pattern node - check it directly
3477-
_ = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(field_idx)));
3588+
// It's a pattern node - could be shorthand field or rest pattern
3589+
const field_patt = cir2.getPatt(@enumFromInt(@intFromEnum(field_idx)));
3590+
3591+
if (field_patt.tag == .ident) {
3592+
// Shorthand field pattern: { x } means { x: x }
3593+
const field_name = field_patt.payload.ident;
3594+
const field_type = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(field_idx)));
3595+
3596+
try fields.append(.{
3597+
.name = field_name,
3598+
.var_ = field_type,
3599+
});
3600+
} else if (field_patt.tag == .double_dot_ident) {
3601+
// Rest pattern: { ..rest }
3602+
has_rest = true;
3603+
rest_var = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(field_idx)));
3604+
} else {
3605+
// Other pattern types in record context
3606+
_ = try self.checkCIR2Pattern(CIR2, cir2, @enumFromInt(@intFromEnum(field_idx)));
3607+
}
34783608
}
34793609
}
34803610

3481-
// Record type - for now use fresh variable
3611+
// Create record type with collected fields
3612+
if (fields.items.len > 0 or has_rest) {
3613+
const record_fields = if (fields.items.len > 0)
3614+
try self.types.appendRecordFields(fields.items)
3615+
else
3616+
types_mod.RecordField.SafeMultiList.Range.empty();
3617+
3618+
// Extension is the rest pattern variable or unbound
3619+
const record_ext = rest_var orelse try self.types.fresh();
3620+
3621+
const record_content = Content{ .structure = .{ .record = .{ .fields = record_fields, .ext = record_ext } } };
3622+
3623+
const record_var = try self.types.freshFromContent(record_content);
3624+
_ = try self.unify(patt_var, record_var);
3625+
}
3626+
34823627
return patt_var;
34833628
},
34843629

src/parse/AST2.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ pub fn forLoop(self: *const Ast, idx: Node.Idx) ForLoop {
220220
}
221221

222222
/// Get an iterator for lambda args
223-
fn lambdaArgs(self: *const Ast, lambda_val: Lambda) LambdaArgsIterator {
223+
pub fn lambdaArgs(self: *const Ast, lambda_val: Lambda) LambdaArgsIterator {
224224
return LambdaArgsIterator{
225225
.iter = self.node_slices.nodes(&lambda_val.args_idx),
226226
.skipped_body = false,

src/parse/Parser2.zig

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2159,7 +2159,8 @@ fn processState(self: *Parser, state: ParseState) !StateAction {
21592159
// Parse requires signatures as record expressions
21602160
// The signatures are represented as record syntax during parsing
21612161
if (self.peek() == .OpenCurly) {
2162-
requires_signatures = try self.parseRecordExpr();
2162+
// Parse record literal for requires signatures
2163+
requires_signatures = try self.parseBlockOrRecord();
21632164
}
21642165
}
21652166

src/types/TypeWriter.zig

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,14 +492,14 @@ fn writeRecordFields(self: *TypeWriter, fields: RecordField.SafeMultiList.Range,
492492

493493
// Write first field - we already verified that there's at least one field
494494
_ = try self.buf.writer().write(self.getIdent(fields_slice.items(.name)[0]));
495-
_ = try self.buf.writer().write(":");
495+
_ = try self.buf.writer().write(": ");
496496
try self.writeVarWithContext(fields_slice.items(.var_)[0], .RecordFieldContent, root_var);
497497

498498
// Write remaining fields
499499
for (fields_slice.items(.name)[1..], fields_slice.items(.var_)[1..]) |name, var_| {
500500
_ = try self.buf.writer().write(", ");
501501
_ = try self.buf.writer().write(self.getIdent(name));
502-
_ = try self.buf.writer().write(":");
502+
_ = try self.buf.writer().write(": ");
503503
try self.writeVarWithContext(var_, .RecordFieldContent, root_var);
504504
}
505505

@@ -535,7 +535,7 @@ fn writeRecord(self: *TypeWriter, record: Record, root_var: Var) std.mem.Allocat
535535
for (fields.items(.name), fields.items(.var_), 0..) |field_name, field_var, i| {
536536
if (i > 0) _ = try self.buf.writer().write(", ");
537537
_ = try self.buf.writer().write(self.getIdent(field_name));
538-
_ = try self.buf.writer().write(":");
538+
_ = try self.buf.writer().write(": ");
539539
try self.writeVarWithContext(field_var, .RecordFieldContent, root_var);
540540
}
541541

@@ -550,7 +550,7 @@ fn writeRecord(self: *TypeWriter, record: Record, root_var: Var) std.mem.Allocat
550550
for (ext_fields.items(.name), ext_fields.items(.var_)) |field_name, field_var| {
551551
if (fields.len > 0 or ext_fields.len > 0) _ = try self.buf.writer().write(", ");
552552
_ = try self.buf.writer().write(self.getIdent(field_name));
553-
_ = try self.buf.writer().write(":");
553+
_ = try self.buf.writer().write(": ");
554554
try self.writeVarWithContext(field_var, .RecordFieldContent, root_var);
555555
}
556556
// Recursively handle the extension's extension
@@ -596,7 +596,7 @@ fn writeRecordExtension(self: *TypeWriter, ext_var: Var, num_fields: usize, root
596596
for (ext_fields.items(.name), ext_fields.items(.var_)) |field_name, field_var| {
597597
_ = try self.buf.writer().write(", ");
598598
_ = try self.buf.writer().write(self.getIdent(field_name));
599-
_ = try self.buf.writer().write(":");
599+
_ = try self.buf.writer().write(": ");
600600
try self.writeVarWithContext(field_var, .RecordFieldContent, root_var);
601601
}
602602
// Recursively handle the extension's extension

0 commit comments

Comments
 (0)