Skip to content

Commit 201b3fb

Browse files
committed
variance modifiers
1 parent f24781b commit 201b3fb

17 files changed

+330
-41
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## [Unreleased]
44
### Added
55
- `collections.User*` should have `__repr__`
6+
- explicit and use-site variance modifiers `In`/`Out`/`InOut`
67

78
## [2.8.1]
89
### Fixes

docs/source/based_features.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ Using the ``&`` operator or ``basedtyping.Intersection`` you can denote intersec
2525
x.reset()
2626
x.add("first")
2727
28+
Explicit and Use-Site variance
29+
------------------------------
30+
31+
it is frequently desirable to explicitly declare the variance of type paramters on types and classes
32+
33+
2834
Type Joins
2935
----------
3036

mypy/checkmember.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
ARG_POS,
1818
ARG_STAR,
1919
ARG_STAR2,
20+
CONTRAVARIANT,
21+
COVARIANT,
2022
EXCLUDED_ENUM_ATTRIBUTES,
2123
SYMBOL_FUNCBASE_TYPES,
2224
Context,
@@ -811,7 +813,9 @@ def analyze_var(
811813
mx.msg.cant_assign_to_classvar(name, mx.context)
812814
t = freshen_all_functions_type_vars(typ)
813815
t = expand_self_type_if_needed(t, mx, var, original_itype)
814-
t = expand_type_by_instance(t, itype)
816+
t = expand_type_by_instance(
817+
t, itype, use_variance=CONTRAVARIANT if mx.is_lvalue else COVARIANT
818+
)
815819
freeze_all_type_vars(t)
816820
result = t
817821
typ = get_proper_type(typ)

mypy/expandtype.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload
3+
from contextlib import contextmanager
4+
from typing import Final, Generator, Iterable, Mapping, Sequence, TypeVar, cast, overload
45

5-
from mypy.nodes import ARG_STAR, FakeInfo, Var
6+
from mypy.nodes import ARG_STAR, CONTRAVARIANT, COVARIANT, FakeInfo, Var
67
from mypy.state import state
78
from mypy.types import (
89
ANY_STRATEGY,
@@ -38,6 +39,7 @@
3839
UninhabitedType,
3940
UnionType,
4041
UnpackType,
42+
VarianceModifier,
4143
flatten_nested_unions,
4244
get_proper_type,
4345
split_with_prefix_and_suffix,
@@ -53,37 +55,49 @@
5355

5456

5557
@overload
56-
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ...
58+
def expand_type(
59+
typ: CallableType, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
60+
) -> CallableType: ...
5761

5862

5963
@overload
60-
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ...
64+
def expand_type(
65+
typ: ProperType, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
66+
) -> ProperType: ...
6167

6268

6369
@overload
64-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ...
70+
def expand_type(
71+
typ: Type, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
72+
) -> Type: ...
6573

6674

67-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
75+
def expand_type(typ: Type, env: Mapping[TypeVarId, Type], *, variance=None) -> Type:
6876
"""Substitute any type variable references in a type given by a type
6977
environment.
7078
"""
71-
return typ.accept(ExpandTypeVisitor(env))
79+
return typ.accept(ExpandTypeVisitor(env, variance=variance))
7280

7381

7482
@overload
75-
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: ...
83+
def expand_type_by_instance(
84+
typ: CallableType, instance: Instance, *, use_variance: int | None = ...
85+
) -> CallableType: ...
7686

7787

7888
@overload
79-
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ...
89+
def expand_type_by_instance(
90+
typ: ProperType, instance: Instance, *, use_variance: int | None = ...
91+
) -> ProperType: ...
8092

8193

8294
@overload
83-
def expand_type_by_instance(typ: Type, instance: Instance) -> Type: ...
95+
def expand_type_by_instance(
96+
typ: Type, instance: Instance, *, use_variance: int | None = ...
97+
) -> Type: ...
8498

8599

86-
def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
100+
def expand_type_by_instance(typ: Type, instance: Instance, use_variance=None) -> Type:
87101
"""Substitute type variables in type using values from an Instance.
88102
Type variables are considered to be bound by the class declaration."""
89103
if not instance.args and not instance.type.has_type_var_tuple_type:
@@ -108,12 +122,11 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
108122
else:
109123
tvars = tuple(instance.type.defn.type_vars)
110124
instance_args = instance.args
111-
112125
for binder, arg in zip(tvars, instance_args):
113126
assert isinstance(binder, TypeVarLikeType)
114127
variables[binder.id] = arg
115128

116-
return expand_type(typ, variables)
129+
return expand_type(typ, variables, variance=use_variance)
117130

118131

119132
F = TypeVar("F", bound=FunctionLike)
@@ -181,10 +194,28 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
181194

182195
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
183196

184-
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
197+
def __init__(
198+
self, variables: Mapping[TypeVarId, Type], *, variance: int | None = None
199+
) -> None:
185200
super().__init__()
186201
self.variables = variables
187202
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
203+
self.variance = variance
204+
self.using_variance: int | None = None
205+
206+
@contextmanager
207+
def in_variance(self) -> Generator[None]:
208+
using_variance = self.using_variance
209+
self.using_variance = CONTRAVARIANT
210+
yield
211+
self.using_variance = using_variance
212+
213+
@contextmanager
214+
def out_variance(self) -> Generator[None]:
215+
using_variance = self.using_variance
216+
self.using_variance = COVARIANT
217+
yield
218+
self.using_variance = using_variance
188219

189220
def visit_unbound_type(self, t: UnboundType) -> Type:
190221
return t
@@ -238,6 +269,18 @@ def visit_type_var(self, t: TypeVarType) -> Type:
238269
if t.id.is_self():
239270
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
240271
repl = self.variables.get(t.id, t)
272+
use_site_variance = repl.variance if isinstance(repl, VarianceModifier) else None
273+
positional_variance = self.using_variance or self.variance
274+
if (
275+
positional_variance is not None
276+
and use_site_variance is not None
277+
and positional_variance != use_site_variance
278+
):
279+
repl = (
280+
t.upper_bound.accept(self)
281+
if positional_variance == COVARIANT
282+
else UninhabitedType()
283+
)
241284
if isinstance(repl, ProperType) and isinstance(repl, Instance):
242285
# TODO: do we really need to do this?
243286
# If I try to remove this special-casing ~40 tests fail on reveal_type().
@@ -414,10 +457,15 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
414457
needs_normalization = True
415458
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
416459
else:
417-
arg_types = self.expand_types(t.arg_types)
460+
with self.in_variance():
461+
arg_types = self.expand_types(t.arg_types)
462+
with self.out_variance():
463+
ret_type = t.ret_type.accept(self)
464+
if isinstance(ret_type, VarianceModifier):
465+
ret_type = ret_type.value
418466
expanded = t.copy_modified(
419467
arg_types=arg_types,
420-
ret_type=t.ret_type.accept(self),
468+
ret_type=ret_type,
421469
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
422470
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
423471
)
@@ -538,7 +586,10 @@ def visit_typeguard_type(self, t: TypeGuardType) -> Type:
538586
def expand_types(self, types: Iterable[Type]) -> list[Type]:
539587
a: list[Type] = []
540588
for t in types:
541-
a.append(t.accept(self))
589+
typ = t.accept(self)
590+
if isinstance(typ, VarianceModifier):
591+
typ = typ.value
592+
a.append(typ)
542593
return a
543594

544595

mypy/message_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
113113
)
114114
FORMAT_REQUIRES_MAPPING: Final = "Format requires a mapping"
115115
RETURN_TYPE_CANNOT_BE_CONTRAVARIANT: Final = ErrorMessage(
116-
"This usage of this contravariant type variable is unsafe as a return type.",
116+
"This usage of this contravariant type variable is unsafe as a return type",
117117
codes.UNSAFE_VARIANCE,
118118
)
119119
FUNCTION_PARAMETER_CANNOT_BE_COVARIANT: Final = ErrorMessage(
120-
"This usage of this covariant type variable is unsafe as an input parameter.",
120+
"This usage of this covariant type variable is unsafe as an input parameter",
121121
codes.UNSAFE_VARIANCE,
122122
)
123123
UNSAFE_VARIANCE_NOTE = ErrorMessage(

mypy/messages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
UninhabitedType,
9494
UnionType,
9595
UnpackType,
96+
VarianceModifier,
9697
flatten_nested_unions,
9798
get_proper_type,
9899
get_proper_types,
@@ -2676,6 +2677,9 @@ def format_literal_value(typ: LiteralType) -> str:
26762677
type_str += f"[{format_list(typ.args)}]"
26772678
return type_str
26782679

2680+
if isinstance(typ, VarianceModifier):
2681+
return typ.render(format)
2682+
26792683
# TODO: always mention type alias names in errors.
26802684
typ = get_proper_type(typ)
26812685

mypy/plugins/proper_plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def is_special_target(right: ProperType) -> bool:
107107
"mypy.types.DeletedType",
108108
"mypy.types.RequiredType",
109109
"mypy.types.ReadOnlyType",
110+
"mypy.types.VarianceModifier",
110111
):
111112
# Special case: these are not valid targets for a type alias and thus safe.
112113
# TODO: introduce a SyntheticType base to simplify this?

mypy/semanal.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@
298298
UnionType,
299299
UnpackType,
300300
UntypedType,
301+
VarianceModifier,
301302
get_proper_type,
302303
get_proper_types,
303304
has_type_vars,
@@ -1962,14 +1963,25 @@ def analyze_type_param(
19621963
self, type_param: TypeParam, context: Context
19631964
) -> TypeVarLikeExpr | None:
19641965
fullname = self.qualified_name(type_param.name)
1966+
variance = VARIANCE_NOT_READY
1967+
upper_bound = None
19651968
if type_param.upper_bound:
1966-
upper_bound = self.anal_type(type_param.upper_bound, allow_placeholder=True)
1967-
# TODO: we should validate the upper bound is valid for a given kind.
1968-
if upper_bound is None:
1969-
# This and below copies special-casing for old-style type variables, that
1970-
# is equally necessary for new-style classes to break a vicious circle.
1971-
upper_bound = PlaceholderType(None, [], context.line)
1972-
else:
1969+
variance_or_bound = self.anal_type(
1970+
type_param.upper_bound,
1971+
allow_placeholder=True,
1972+
is_type_var_bound=isinstance(context, (ClassDef, TypeAliasStmt)),
1973+
)
1974+
if isinstance(variance_or_bound, VarianceModifier):
1975+
variance = variance_or_bound.variance
1976+
upper_bound = variance_or_bound._value
1977+
else:
1978+
upper_bound = variance_or_bound
1979+
# TODO: we should validate the upper bound is valid for a given kind.
1980+
if upper_bound is None:
1981+
# This and below copies special-casing for old-style type variables, that
1982+
# is equally necessary for new-style classes to break a vicious circle.
1983+
upper_bound = PlaceholderType(None, [], context.line)
1984+
if upper_bound is None:
19731985
if type_param.kind == TYPE_VAR_TUPLE_KIND:
19741986
upper_bound = self.named_type("builtins.tuple", [self.object_type()])
19751987
else:
@@ -2012,7 +2024,7 @@ def analyze_type_param(
20122024
values=values,
20132025
upper_bound=upper_bound,
20142026
default=default,
2015-
variance=VARIANCE_NOT_READY,
2027+
variance=variance,
20162028
is_new_style=True,
20172029
line=context.line,
20182030
)
@@ -6261,6 +6273,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> list[Type] | None:
62616273
allow_param_spec_literals=has_param_spec,
62626274
allow_unpack=allow_unpack,
62636275
runtime=True,
6276+
nested=True,
62646277
)
62656278
if analyzed is None:
62666279
return None
@@ -7545,6 +7558,7 @@ def type_analyzer(
75457558
report_invalid_types: bool = True,
75467559
prohibit_self_type: str | None = None,
75477560
allow_type_any: bool = False,
7561+
is_type_var_bound=False,
75487562
) -> TypeAnalyser:
75497563
if tvar_scope is None:
75507564
tvar_scope = self.tvar_scope
@@ -7564,6 +7578,7 @@ def type_analyzer(
75647578
allow_unpack=allow_unpack,
75657579
prohibit_self_type=prohibit_self_type,
75667580
allow_type_any=allow_type_any,
7581+
is_type_var_bound=is_type_var_bound,
75677582
)
75687583
tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic())
75697584
tpan.global_scope = not self.type and not self.function_stack
@@ -7589,6 +7604,8 @@ def anal_type(
75897604
prohibit_self_type: str | None = None,
75907605
allow_type_any: bool = False,
75917606
runtime: bool | None = None,
7607+
is_type_var_bound=False,
7608+
nested=False,
75927609
) -> Type | None:
75937610
"""Semantically analyze a type.
75947611
@@ -7624,13 +7641,16 @@ def anal_type(
76247641
report_invalid_types=report_invalid_types,
76257642
prohibit_self_type=prohibit_self_type,
76267643
allow_type_any=allow_type_any,
7644+
is_type_var_bound=is_type_var_bound,
76277645
)
76287646
if not a.api.is_stub_file and runtime:
76297647
a.always_allow_new_syntax = False
76307648
if runtime is False:
76317649
a.always_allow_new_syntax = True
76327650
if self.is_stub_file:
76337651
a.always_allow_new_syntax = True
7652+
if nested:
7653+
a.nesting_level += 1
76347654
tag = self.track_incomplete_refs()
76357655
typ = typ.accept(a)
76367656
if self.found_incomplete_ref(tag):

mypy/subtypes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
UnionType,
6868
UnpackType,
6969
UntypedType,
70+
VarianceModifier,
7071
find_unpack_in_list,
7172
get_proper_type,
7273
is_named_instance,
@@ -378,6 +379,8 @@ def check_type_parameter(
378379
p_left = get_proper_type(left)
379380
if isinstance(p_left, UninhabitedType) and p_left.ambiguous:
380381
variance = COVARIANT
382+
if isinstance(right, VarianceModifier):
383+
variance = right.variance
381384
# If variance hasn't been inferred yet, we are lenient and default to
382385
# covariance. This shouldn't happen often, but it's very difficult to
383386
# avoid these cases altogether.

mypy/type_visitor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
UninhabitedType,
5050
UnionType,
5151
UnpackType,
52+
VarianceModifier,
5253
get_proper_type,
5354
)
5455

@@ -87,6 +88,10 @@ def visit_erased_type(self, t: ErasedType) -> T:
8788
def visit_deleted_type(self, t: DeletedType) -> T:
8889
pass
8990

91+
def visit_variance_modifier(self, t: VarianceModifier) -> T:
92+
assert t.value
93+
return t.value.accept(self)
94+
9095
@abstractmethod
9196
def visit_type_var(self, t: TypeVarType) -> T:
9297
pass
@@ -245,6 +250,9 @@ def visit_instance(self, t: Instance) -> Type:
245250
result.metadata = t.metadata
246251
return result
247252

253+
def visit_variance_modifier(self, t: VarianceModifier) -> Type:
254+
return VarianceModifier(t.variance, t.value.accept(self))
255+
248256
def visit_type_var(self, t: TypeVarType) -> Type:
249257
return t
250258

0 commit comments

Comments
 (0)