11from __future__ import annotations
22
3- from typing import Final , Iterable , Mapping , Sequence , TypeVar , cast , overload
3+ from contextlib import contextmanager
4+ from typing import Final , Generator , Iterable , Mapping , Sequence , TypeVar , cast , overload
45
5- from mypy .nodes import ARG_STAR , FakeInfo , Var
6+ from mypy .nodes import ARG_STAR , CONTRAVARIANT , COVARIANT , FakeInfo , Var
67from mypy .state import state
78from mypy .types import (
89 ANY_STRATEGY ,
3839 UninhabitedType ,
3940 UnionType ,
4041 UnpackType ,
42+ VarianceModifier ,
4143 flatten_nested_unions ,
4244 get_proper_type ,
4345 split_with_prefix_and_suffix ,
5355
5456
5557@overload
56- def expand_type (typ : CallableType , env : Mapping [TypeVarId , Type ]) -> CallableType : ...
58+ def expand_type (
59+ typ : CallableType , env : Mapping [TypeVarId , Type ], * , variance : int | None = ...
60+ ) -> CallableType : ...
5761
5862
5963@overload
60- def expand_type (typ : ProperType , env : Mapping [TypeVarId , Type ]) -> ProperType : ...
64+ def expand_type (
65+ typ : ProperType , env : Mapping [TypeVarId , Type ], * , variance : int | None = ...
66+ ) -> ProperType : ...
6167
6268
6369@overload
64- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type : ...
70+ def expand_type (
71+ typ : Type , env : Mapping [TypeVarId , Type ], * , variance : int | None = ...
72+ ) -> Type : ...
6573
6674
67- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type :
75+ def expand_type (typ : Type , env : Mapping [TypeVarId , Type ], * , variance = None ) -> Type :
6876 """Substitute any type variable references in a type given by a type
6977 environment.
7078 """
71- return typ .accept (ExpandTypeVisitor (env ))
79+ return typ .accept (ExpandTypeVisitor (env , variance = variance ))
7280
7381
7482@overload
75- def expand_type_by_instance (typ : CallableType , instance : Instance ) -> CallableType : ...
83+ def expand_type_by_instance (
84+ typ : CallableType , instance : Instance , * , use_variance : int | None = ...
85+ ) -> CallableType : ...
7686
7787
7888@overload
79- def expand_type_by_instance (typ : ProperType , instance : Instance ) -> ProperType : ...
89+ def expand_type_by_instance (
90+ typ : ProperType , instance : Instance , * , use_variance : int | None = ...
91+ ) -> ProperType : ...
8092
8193
8294@overload
83- def expand_type_by_instance (typ : Type , instance : Instance ) -> Type : ...
95+ def expand_type_by_instance (
96+ typ : Type , instance : Instance , * , use_variance : int | None = ...
97+ ) -> Type : ...
8498
8599
86- def expand_type_by_instance (typ : Type , instance : Instance ) -> Type :
100+ def expand_type_by_instance (typ : Type , instance : Instance , use_variance = None ) -> Type :
87101 """Substitute type variables in type using values from an Instance.
88102 Type variables are considered to be bound by the class declaration."""
89103 if not instance .args and not instance .type .has_type_var_tuple_type :
@@ -108,12 +122,11 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
108122 else :
109123 tvars = tuple (instance .type .defn .type_vars )
110124 instance_args = instance .args
111-
112125 for binder , arg in zip (tvars , instance_args ):
113126 assert isinstance (binder , TypeVarLikeType )
114127 variables [binder .id ] = arg
115128
116- return expand_type (typ , variables )
129+ return expand_type (typ , variables , variance = use_variance )
117130
118131
119132F = TypeVar ("F" , bound = FunctionLike )
@@ -181,10 +194,28 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
181194
182195 variables : Mapping [TypeVarId , Type ] # TypeVar id -> TypeVar value
183196
184- def __init__ (self , variables : Mapping [TypeVarId , Type ]) -> None :
197+ def __init__ (
198+ self , variables : Mapping [TypeVarId , Type ], * , variance : int | None = None
199+ ) -> None :
185200 super ().__init__ ()
186201 self .variables = variables
187202 self .recursive_tvar_guard : dict [TypeVarId , Type | None ] = {}
203+ self .variance = variance
204+ self .using_variance : int | None = None
205+
206+ @contextmanager
207+ def in_variance (self ) -> Generator [None ]:
208+ using_variance = self .using_variance
209+ self .using_variance = CONTRAVARIANT
210+ yield
211+ self .using_variance = using_variance
212+
213+ @contextmanager
214+ def out_variance (self ) -> Generator [None ]:
215+ using_variance = self .using_variance
216+ self .using_variance = COVARIANT
217+ yield
218+ self .using_variance = using_variance
188219
189220 def visit_unbound_type (self , t : UnboundType ) -> Type :
190221 return t
@@ -238,6 +269,18 @@ def visit_type_var(self, t: TypeVarType) -> Type:
238269 if t .id .is_self ():
239270 t = t .copy_modified (upper_bound = t .upper_bound .accept (self ))
240271 repl = self .variables .get (t .id , t )
272+ use_site_variance = repl .variance if isinstance (repl , VarianceModifier ) else None
273+ positional_variance = self .using_variance or self .variance
274+ if (
275+ positional_variance is not None
276+ and use_site_variance is not None
277+ and positional_variance != use_site_variance
278+ ):
279+ repl = (
280+ t .upper_bound .accept (self )
281+ if positional_variance == COVARIANT
282+ else UninhabitedType ()
283+ )
241284 if isinstance (repl , ProperType ) and isinstance (repl , Instance ):
242285 # TODO: do we really need to do this?
243286 # If I try to remove this special-casing ~40 tests fail on reveal_type().
@@ -414,10 +457,15 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
414457 needs_normalization = True
415458 arg_types = self .interpolate_args_for_unpack (t , var_arg .typ )
416459 else :
417- arg_types = self .expand_types (t .arg_types )
460+ with self .in_variance ():
461+ arg_types = self .expand_types (t .arg_types )
462+ with self .out_variance ():
463+ ret_type = t .ret_type .accept (self )
464+ if isinstance (ret_type , VarianceModifier ):
465+ ret_type = ret_type .value
418466 expanded = t .copy_modified (
419467 arg_types = arg_types ,
420- ret_type = t . ret_type . accept ( self ) ,
468+ ret_type = ret_type ,
421469 type_guard = t .type_guard and cast (TypeGuardType , t .type_guard .accept (self )),
422470 type_is = (t .type_is .accept (self ) if t .type_is is not None else None ),
423471 )
@@ -538,7 +586,10 @@ def visit_typeguard_type(self, t: TypeGuardType) -> Type:
538586 def expand_types (self , types : Iterable [Type ]) -> list [Type ]:
539587 a : list [Type ] = []
540588 for t in types :
541- a .append (t .accept (self ))
589+ typ = t .accept (self )
590+ if isinstance (typ , VarianceModifier ):
591+ typ = typ .value
592+ a .append (typ )
542593 return a
543594
544595
0 commit comments