Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 2 additions & 8 deletions compiler-core/building/src/engine/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@ impl SnapshotGraph {
}

pub(crate) fn remove_edge(&mut self, to_id: SnapshotId) {
let keys: Vec<_> = self
.inner
.iter()
.filter_map(|(&from_id, &id)| if id == to_id { Some(from_id) } else { None })
.collect();
keys.iter().for_each(|key| {
self.inner.remove(key);
});
self.inner.retain(|_, &mut id| id != to_id);
self.inner.remove(&to_id);
}
}

Expand Down
129 changes: 122 additions & 7 deletions compiler-core/checking/src/check/convert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter;

use itertools::Itertools;
use smol_str::SmolStr;

Expand Down Expand Up @@ -39,13 +41,13 @@ where
};
state.storage.intern(Type::Constructor(file_id, type_id))
}
lowering::TypeKind::Forall { bindings, type_ } => {
lowering::TypeKind::Forall { bindings, inner } => {
let binders = bindings
.iter()
.map(|binding| convert_forall_binding(state, context, binding))
.collect_vec();

let inner = type_.map_or(default, |id| type_to_core(state, context, id));
let inner = inner.map_or(default, |id| type_to_core(state, context, id));

let forall = binders
.into_iter()
Expand All @@ -65,7 +67,7 @@ where
lowering::TypeKind::String => default,
lowering::TypeKind::Variable { name, resolution } => {
let Some(resolution) = resolution else {
let name = name.clone().unwrap_or(INVALID_NAME);
let name = name.clone().unwrap_or(MISSING_NAME);
let kind = Variable::Free(name);
return state.storage.intern(Type::Variable(kind));
};
Expand All @@ -90,9 +92,122 @@ where
}
}

const INVALID_NAME: SmolStr = SmolStr::new_inline("<invalid>");
/// A variant of [`type_to_core`] for use with signature declarations.
///
/// Unlike the regular [`type_to_core`], this function does not call
/// [`CheckState::unbind`] after each [`lowering::TypeKind::Forall`]
/// node. This allows type variables to be scoped for the entire
/// declaration group rather than just the type signature.
pub fn signature_type_to_core<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
id: lowering::TypeId,
) -> TypeId
where
Q: ExternalQueries,
{
let default = context.prim.unknown;

let Some(kind) = context.lowered.info.get_type_kind(id) else {
return default;
};

match kind {
lowering::TypeKind::Forall { bindings, inner } => {
let binders = bindings
.iter()
.map(|binding| convert_forall_binding(state, context, binding))
.collect_vec();

let inner = inner.map_or(default, |id| type_to_core(state, context, id));

binders
.into_iter()
.rfold(inner, |inner, binder| state.storage.intern(Type::Forall(binder, inner)))
}

lowering::TypeKind::Parenthesized { parenthesized } => {
parenthesized.map(|id| signature_type_to_core(state, context, id)).unwrap_or(default)
}

_ => type_to_core(state, context, id),
}
}

pub struct InspectSignature {
pub variables: Vec<ForallBinder>,
pub arguments: Vec<TypeId>,
pub result: TypeId,
}

pub fn inspect_signature<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
id: lowering::TypeId,
) -> InspectSignature
where
Q: ExternalQueries,
{
let unknown = || {
let variables = [].into();
let arguments = [].into();
let result = context.prim.unknown;
InspectSignature { variables, arguments, result }
};

let Some(kind) = context.lowered.info.get_type_kind(id) else {
return unknown();
};

match kind {
lowering::TypeKind::Forall { bindings, inner } => {
let variables = bindings
.iter()
.map(|binding| convert_forall_binding(state, context, binding))
.collect();

let inner = inner.map_or(context.prim.unknown, |id| type_to_core(state, context, id));
let (arguments, result) = signature_components(state, inner);

InspectSignature { variables, arguments, result }
}

lowering::TypeKind::Parenthesized { parenthesized } => {
parenthesized.map(|id| inspect_signature(state, context, id)).unwrap_or_else(unknown)
}

_ => {
let variables = [].into();

let id = type_to_core(state, context, id);
let (arguments, result) = signature_components(state, id);

InspectSignature { variables, arguments, result }
}
}
}

fn signature_components(state: &mut CheckState, id: TypeId) -> (Vec<TypeId>, TypeId) {
let mut components = iter::successors(Some(id), |&id| match state.storage[id] {
Type::Function(_, id) => Some(id),
_ => None,
})
.map(|id| match state.storage[id] {
Type::Function(id, _) => id,
_ => id,
})
.collect_vec();

let Some(id) = components.pop() else {
unreachable!("invariant violated: expected non-empty components");
};

(components, id)
}

const MISSING_NAME: SmolStr = SmolStr::new_inline("<MissingName>");

fn convert_forall_binding<Q>(
pub fn convert_forall_binding<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
binding: &lowering::TypeVariableBinding,
Expand All @@ -101,11 +216,11 @@ where
Q: ExternalQueries,
{
let visible = binding.visible;
let name = binding.name.clone().unwrap_or(INVALID_NAME);
let name = binding.name.clone().unwrap_or(MISSING_NAME);

let kind = match binding.kind {
Some(id) => type_to_core(state, context, id),
None => state.fresh_unification(context),
None => state.fresh_unification_type(context),
};

let level = state.bind_forall(binding.id, kind);
Expand Down
25 changes: 18 additions & 7 deletions compiler-core/checking/src/check/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,24 @@ where
lowering::TypeKind::Constructor { resolution } => {
let Some((file_id, type_id)) = *resolution else { return default };

let Ok(checked) = context.queries.checked(file_id) else { return default };
let Some(global_id) = checked.lookup_type(type_id) else { return default };
if file_id == context.id {
let t = convert::type_to_core(state, context, id);
if let Some(&k) = state.binding_group.types.get(&type_id) {
(t, k)
} else if let Some(&k) = state.checked.types.get(&type_id) {
(t, k)
} else {
default
}
} else {
let Ok(checked) = context.queries.checked(file_id) else { return default };
let Some(global_id) = checked.lookup_type(type_id) else { return default };

let t = convert::type_to_core(state, context, id);
let k = transfer::localize(state, context, global_id);
let t = convert::type_to_core(state, context, id);
let k = transfer::localize(state, context, global_id);

(t, k)
(t, k)
}
}

lowering::TypeKind::Forall { .. } => {
Expand Down Expand Up @@ -223,7 +234,7 @@ where

Type::Variable(ref variable) => match variable {
Variable::Implicit(_) => context.prim.unknown,
Variable::Skolem(_) => context.prim.unknown,
Variable::Skolem(_, kind) => *kind,
Variable::Bound(index) => {
let size = state.bound.size();

Expand Down Expand Up @@ -254,6 +265,6 @@ where
Q: ExternalQueries,
{
let (inferred_type, inferred_kind) = infer_surface_kind(state, context, id);
unification::unify(state, context, inferred_kind, kind);
unification::subsumes(state, context, inferred_kind, kind);
(inferred_type, inferred_kind)
}
48 changes: 46 additions & 2 deletions compiler-core/checking/src/check/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ use std::sync::Arc;

use building_types::QueryResult;
use files::FileId;
use indexing::IndexedModule;
use indexing::{IndexedModule, TermItemId, TypeItemId};
use itertools::Itertools;
use lowering::{GraphNodeId, ImplicitBindingId, LoweredModule, TypeVariableBindingId};
use rustc_hash::FxHashMap;

use crate::check::unification::UnificationContext;
use crate::check::{quantify, transfer};
use crate::core::{Type, TypeId, TypeInterner, debruijn};
use crate::{CheckedModule, ExternalQueries};

Expand All @@ -19,6 +22,7 @@ pub struct CheckState {
pub types: debruijn::BoundMap<TypeId>,

pub unification: UnificationContext,
pub binding_group: BindingGroupContext,
}

impl CheckState {
Expand Down Expand Up @@ -89,6 +93,37 @@ impl CheckState {
self.types.unbind(level);
self.kinds.unbind(level);
}

pub fn type_binding_group<Q>(
&mut self,
context: &CheckContext<Q>,
group: impl AsRef<[TypeItemId]>,
) where
Q: ExternalQueries,
{
for &item in group.as_ref() {
let t = self.fresh_unification_type(context);
self.binding_group.types.insert(item, t);
}
}

pub fn commit_binding_group<Q>(&mut self, context: &CheckContext<Q>)
where
Q: ExternalQueries,
{
for (item_id, type_id) in self.binding_group.terms.drain().collect_vec() {
if let Some(type_id) = quantify::quantify(self, type_id) {
let type_id = transfer::globalize(self, context, type_id);
self.checked.terms.insert(item_id, type_id);
}
}
for (item_id, type_id) in self.binding_group.types.drain().collect_vec() {
if let Some(type_id) = quantify::quantify(self, type_id) {
let type_id = transfer::globalize(self, context, type_id);
self.checked.types.insert(item_id, type_id);
}
}
}
}

pub struct CheckContext<'a, Q>
Expand All @@ -97,8 +132,11 @@ where
{
pub queries: &'a Q,
pub prim: PrimCore,

pub id: FileId,
pub indexed: Arc<IndexedModule>,
pub lowered: Arc<LoweredModule>,

pub prim_indexed: Arc<IndexedModule>,
}

Expand All @@ -116,7 +154,7 @@ where
let prim = PrimCore::collect(queries, state)?;
let prim_id = queries.prim_id();
let prim_indexed = queries.indexed(prim_id)?;
Ok(CheckContext { queries, prim, indexed, lowered, prim_indexed })
Ok(CheckContext { queries, prim, id, indexed, lowered, prim_indexed })
}
}

Expand Down Expand Up @@ -169,3 +207,9 @@ impl PrimCore {
})
}
}

#[derive(Default)]
pub struct BindingGroupContext {
pub terms: FxHashMap<TermItemId, TypeId>,
pub types: FxHashMap<TypeItemId, TypeId>,
}
Loading