From d94c20b433c160276e9223c8b1841f0582a8a4db Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 16 Oct 2025 22:03:42 +0900 Subject: [PATCH] fix --- pyrefly/lib/alt/class/class_field.rs | 172 +++++++++++++++++++++++++++ pyrefly/lib/test/descriptors.rs | 26 ++++ 2 files changed, 198 insertions(+) diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index be422efe05..95da80b129 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -31,6 +31,7 @@ use ruff_python_ast::name::Name; use ruff_text_size::TextRange; use starlark_map::small_map::SmallMap; use starlark_map::small_set::SmallSet; +use vec1::Vec1; use vec1::vec1; use crate::alt::answers::LookupAnswer; @@ -1264,6 +1265,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } else { value_ty }; + let descriptor_value_ty = value_ty.clone(); // Types provided in annotations shadow inferred types let ty = if let Some(ann) = annotation { @@ -1329,6 +1331,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ty }; + if descriptor.is_some() { + self.validate_descriptor_annotation( + class, + name, + annotation, + &descriptor_value_ty, + range, + errors, + ); + } + // Pin any vars in the type: leaking a var in a class field is particularly // likely to lead to data races where downstream uses can pin inconsistently. // @@ -1470,6 +1483,165 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None } + fn collect_descriptor_classes_from_type(&self, out: &mut SmallSet, ty: &Type) { + match ty { + Type::ClassType(cls) => { + out.insert(cls.clone()); + } + Type::Union(types) => { + for ty in types { + self.collect_descriptor_classes_from_type(out, ty); + } + } + _ => {} + } + } + + fn descriptor_info_from_class(&self, cls: ClassType, range: TextRange) -> Option { + let getter = self + .get_class_member(cls.class_object(), &dunder::GET) + .is_some(); + let setter = self + .get_class_member(cls.class_object(), &dunder::SET) + .is_some(); + if getter || setter { + Some(Descriptor { + range, + cls, + getter, + setter, + }) + } else { + None + } + } + + fn validate_descriptor_annotation( + &self, + class: &Class, + name: &Name, + annotation: Option<&Annotation>, + value_ty: &Type, + range: TextRange, + errors: &ErrorCollector, + ) { + let Some(annotation) = annotation else { + return; + }; + let expected_ty = annotation.get_type().clone(); + if expected_ty.is_any() || expected_ty.is_error() { + return; + } + + let mut descriptor_classes = SmallSet::new(); + self.collect_descriptor_classes_from_type(&mut descriptor_classes, value_ty); + + if descriptor_classes.is_empty() { + return; + } + + let class_type = self.as_class_type_unchecked(class); + let mut messages = Vec::new(); + let ignore_errors = self.error_swallower(); + + for cls in descriptor_classes { + if let Type::ClassType(expected_cls) = &expected_ty + && expected_cls.class_object() == cls.class_object() + { + continue; + } + let Some(desc_info) = self.descriptor_info_from_class(cls.clone(), range) else { + continue; + }; + + if desc_info.getter { + if let Some(getter_method) = + self.resolve_descriptor_getter(&desc_info, &ignore_errors) + { + let instance_ret = self.call_descriptor_getter( + getter_method.clone(), + DescriptorBase::Instance(class_type.clone()), + range, + &ignore_errors, + None, + ); + if !self.is_subset_eq(&instance_ret, &expected_ty) { + let actual_display = self + .for_display(instance_ret.deterministic_printing()) + .deterministic_printing(); + let expected_display = self + .for_display(expected_ty.clone().deterministic_printing()) + .deterministic_printing(); + messages.push(format!( + "Descriptor `{}` returns `{}` when accessed on instances of `{}`, which is not assignable to `{}`", + name, + actual_display, + class.name(), + expected_display, + )); + } + + let class_ret = self.call_descriptor_getter( + getter_method, + DescriptorBase::ClassDef(class.dupe()), + range, + &ignore_errors, + None, + ); + if !self.is_subset_eq(&class_ret, &expected_ty) { + let actual_display = self + .for_display(class_ret.deterministic_printing()) + .deterministic_printing(); + let expected_display = self + .for_display(expected_ty.clone().deterministic_printing()) + .deterministic_printing(); + messages.push(format!( + "Descriptor `{}` returns `{}` when accessed on class `{}`, which is not assignable to `{}`", + name, + actual_display, + class.name(), + expected_display, + )); + } + } + } + + if desc_info.setter { + if let Some(setter_method) = + self.resolve_descriptor_setter(&desc_info, &ignore_errors) + { + let setter_check_errors = self.error_collector(); + self.call_descriptor_setter( + setter_method.clone(), + class_type.clone(), + CallArg::ty(&expected_ty, range), + range, + &setter_check_errors, + None, + ); + if !setter_check_errors.is_empty() { + let setter_display = self + .for_display(setter_method.deterministic_printing()) + .deterministic_printing(); + let expected_display = self + .for_display(expected_ty.clone().deterministic_printing()) + .deterministic_printing(); + messages.push(format!( + "Descriptor `{}` setter `{}` does not accept annotated type `{}`", + name, setter_display, expected_display, + )); + } + } + } + } + + if !messages.is_empty() { + if let Ok(msgs) = Vec1::try_from_vec(messages) { + errors.add(range, ErrorInfo::Kind(ErrorKind::BadAssignment), msgs); + } + } + } + /// Return (type of first inherited field, first inherited annotation). May not be from the same class! /// For example, in: /// class A: diff --git a/pyrefly/lib/test/descriptors.rs b/pyrefly/lib/test/descriptors.rs index 1847aaea52..849caf8bcb 100644 --- a/pyrefly/lib/test/descriptors.rs +++ b/pyrefly/lib/test/descriptors.rs @@ -319,3 +319,29 @@ class A: return self.d "#, ); + +testcase!( + test_descriptor_incompatible_get_return_annotation, + r#" +from typing import Literal +class A: + def __get__(self, obj, objtype) -> A | int: ... +class B(A): + def __get__(self, obj, objtype) -> Literal[1]: ... +class C: + a: A = B() # E: Descriptor `a` returns `Literal[1]` when accessed on instances of `C`, which is not assignable to `A` + "#, +); + +testcase!( + test_descriptor_incompatible_set_annotation, + r#" +from typing import Any +class A: + def __set__(self, obj, value: Any) -> None: ... +class B(A): + def __set__(self, obj, value: str) -> None: ... +class C: + a: A = B() # E: Descriptor `a` setter `BoundMethod[B, (self: B, obj: Unknown, value: str) -> None]` does not accept annotated type `A` + "#, +);