@@ -185,7 +185,7 @@ pub enum ScopeParent {
185
185
186
186
// List of class names that a type refers to, after stripping Optional and Awaitable.
187
187
#[ derive( Debug , Clone , Serialize , PartialEq , Eq ) ]
188
- struct ClassNamesFromType {
188
+ pub struct ClassNamesFromType {
189
189
class_names : Vec < ClassRef > ,
190
190
#[ serde( skip_serializing_if = "<&bool>::not" ) ]
191
191
stripped_coroutine : bool ,
@@ -593,9 +593,9 @@ fn has_superclass(class: &Class, want: &Class, context: &ModuleContext) -> bool
593
593
}
594
594
595
595
impl ClassNamesFromType {
596
- fn from_class ( class : Class , context : & ModuleContext ) -> ClassNamesFromType {
596
+ pub fn from_class ( class : & Class , context : & ModuleContext ) -> ClassNamesFromType {
597
597
ClassNamesFromType {
598
- class_names : vec ! [ ClassRef :: from_class( & class, context. module_ids) ] ,
598
+ class_names : vec ! [ ClassRef :: from_class( class, context. module_ids) ] ,
599
599
stripped_coroutine : false ,
600
600
stripped_optional : false ,
601
601
stripped_readonly : false ,
@@ -604,7 +604,19 @@ impl ClassNamesFromType {
604
604
}
605
605
}
606
606
607
- fn not_a_class ( ) -> ClassNamesFromType {
607
+ #[ cfg( test) ]
608
+ pub fn from_classes ( class_names : Vec < ClassRef > , is_exhaustive : bool ) -> ClassNamesFromType {
609
+ ClassNamesFromType {
610
+ class_names,
611
+ stripped_coroutine : false ,
612
+ stripped_optional : false ,
613
+ stripped_readonly : false ,
614
+ unbound_type_variable : false ,
615
+ is_exhaustive,
616
+ }
617
+ }
618
+
619
+ pub fn not_a_class ( ) -> ClassNamesFromType {
608
620
ClassNamesFromType {
609
621
class_names : vec ! [ ] ,
610
622
stripped_coroutine : false ,
@@ -619,13 +631,13 @@ impl ClassNamesFromType {
619
631
self . class_names . is_empty ( )
620
632
}
621
633
622
- fn with_strip_optional ( mut self ) -> ClassNamesFromType {
623
- self . stripped_optional = true ;
634
+ pub fn with_strip_optional ( mut self , stripped_optional : bool ) -> ClassNamesFromType {
635
+ self . stripped_optional = stripped_optional ;
624
636
self
625
637
}
626
638
627
- fn with_strip_coroutine ( mut self ) -> ClassNamesFromType {
628
- self . stripped_coroutine = true ;
639
+ pub fn with_strip_coroutine ( mut self , stripped_coroutine : bool ) -> ClassNamesFromType {
640
+ self . stripped_coroutine = stripped_coroutine ;
629
641
self
630
642
}
631
643
@@ -701,22 +713,20 @@ fn is_scalar_type(get: &Type, want: &Class, context: &ModuleContext) -> bool {
701
713
702
714
fn get_classes_of_type ( type_ : & Type , context : & ModuleContext ) -> ClassNamesFromType {
703
715
if let Some ( inner) = strip_optional ( type_) {
704
- return get_classes_of_type ( inner, context) . with_strip_optional ( ) ;
716
+ return get_classes_of_type ( inner, context) . with_strip_optional ( true ) ;
705
717
}
706
718
if let Some ( inner) = strip_awaitable ( type_, context) {
707
- return get_classes_of_type ( inner, context) . with_strip_coroutine ( ) ;
719
+ return get_classes_of_type ( inner, context) . with_strip_coroutine ( true ) ;
708
720
}
709
721
if let Some ( inner) = strip_coroutine ( type_, context) {
710
- return get_classes_of_type ( inner, context) . with_strip_coroutine ( ) ;
722
+ return get_classes_of_type ( inner, context) . with_strip_coroutine ( true ) ;
711
723
}
712
724
// No need to strip ReadOnly[], it is already stripped by pyrefly.
713
725
match type_ {
714
726
Type :: ClassType ( class_type) => {
715
- ClassNamesFromType :: from_class ( class_type. class_object ( ) . clone ( ) , context)
716
- }
717
- Type :: Tuple ( _) => {
718
- ClassNamesFromType :: from_class ( context. stdlib . tuple_object ( ) . clone ( ) , context)
727
+ ClassNamesFromType :: from_class ( class_type. class_object ( ) , context)
719
728
}
729
+ Type :: Tuple ( _) => ClassNamesFromType :: from_class ( context. stdlib . tuple_object ( ) , context) ,
720
730
Type :: Union ( elements) if !elements. is_empty ( ) => elements
721
731
. iter ( )
722
732
. map ( |inner| get_classes_of_type ( inner, context) )
@@ -729,6 +739,42 @@ fn get_classes_of_type(type_: &Type, context: &ModuleContext) -> ClassNamesFromT
729
739
}
730
740
731
741
impl PysaType {
742
+ #[ cfg( test) ]
743
+ pub fn new ( string : String , class_names : ClassNamesFromType ) -> PysaType {
744
+ PysaType {
745
+ string,
746
+ is_bool : false ,
747
+ is_int : false ,
748
+ is_float : false ,
749
+ is_enum : false ,
750
+ class_names,
751
+ }
752
+ }
753
+
754
+ #[ cfg( test) ]
755
+ pub fn with_is_bool ( mut self , is_bool : bool ) -> PysaType {
756
+ self . is_bool = is_bool;
757
+ self
758
+ }
759
+
760
+ #[ cfg( test) ]
761
+ pub fn with_is_int ( mut self , is_int : bool ) -> PysaType {
762
+ self . is_int = is_int;
763
+ self
764
+ }
765
+
766
+ #[ cfg( test) ]
767
+ pub fn with_is_float ( mut self , is_float : bool ) -> PysaType {
768
+ self . is_float = is_float;
769
+ self
770
+ }
771
+
772
+ #[ cfg( test) ]
773
+ pub fn with_is_enum ( mut self , is_enum : bool ) -> PysaType {
774
+ self . is_enum = is_enum;
775
+ self
776
+ }
777
+
732
778
pub fn from_type ( type_ : & Type , context : & ModuleContext ) -> PysaType {
733
779
// Promote `Literal[..]` into `str` or `int`.
734
780
let type_ = type_. clone ( ) . promote_literals ( & context. stdlib ) ;
0 commit comments