Skip to content

Commit d6f1faa

Browse files
rework starargs with union argument
1 parent 5a78607 commit d6f1faa

File tree

6 files changed

+577
-188
lines changed

6 files changed

+577
-188
lines changed

mypy/argmap.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from typing import TYPE_CHECKING, Callable
77

88
from mypy import nodes
9+
from mypy.join import join_type_list
910
from mypy.maptype import map_instance_to_supertype
11+
from mypy.typeops import make_simplified_union
1012
from mypy.types import (
1113
AnyType,
1214
Instance,
@@ -16,6 +18,7 @@
1618
TypedDictType,
1719
TypeOfAny,
1820
TypeVarTupleType,
21+
UnionType,
1922
UnpackType,
2023
get_proper_type,
2124
)
@@ -54,6 +57,15 @@ def map_actuals_to_formals(
5457
elif actual_kind == nodes.ARG_STAR:
5558
# We need to know the actual type to map varargs.
5659
actualt = get_proper_type(actual_arg_type(ai))
60+
61+
# Special case for union of equal sized tuples.
62+
if (
63+
isinstance(actualt, UnionType)
64+
and actualt.items
65+
and is_equal_sized_tuples(actualt.items)
66+
):
67+
# Arbitrarily pick the first item in the union.
68+
actualt = get_proper_type(actualt.items[0])
5769
if isinstance(actualt, TupleType):
5870
# A tuple actual maps to a fixed number of formals.
5971
for _ in range(len(actualt.items)):
@@ -171,6 +183,15 @@ def __init__(self, context: ArgumentInferContext) -> None:
171183
# Type context for `*` and `**` arg kinds.
172184
self.context = context
173185

186+
def __eq__(self, other: object) -> bool:
187+
if isinstance(other, ArgTypeExpander):
188+
return (
189+
self.tuple_index == other.tuple_index
190+
and self.kwargs_used == other.kwargs_used
191+
and self.context == other.context
192+
)
193+
return NotImplemented
194+
174195
def expand_actual_type(
175196
self,
176197
actual_type: Type,
@@ -193,6 +214,64 @@ def expand_actual_type(
193214
original_actual = actual_type
194215
actual_type = get_proper_type(actual_type)
195216
if actual_kind == nodes.ARG_STAR:
217+
if isinstance(actual_type, UnionType):
218+
# special case 1: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
219+
# special case 2: union contains no static sized tuples. (e.g. `list[str | None] | list[str]`)
220+
if is_equal_sized_tuples(actual_type.items) or not any(
221+
isinstance(get_proper_type(t), TupleType) for t in actual_type.items
222+
):
223+
# If the actual type is a union, try expanding it.
224+
# Example: f(*args), where args is `list[str | None] | list[str]`,
225+
# Example: f(*args), where args is `tuple[A, B, C] | tuple[None, None, None]`
226+
# Note: there is potential for combinatorial explosion here:
227+
# f(*x1, *x2, .. *xn), if xₖ is a union of nₖ differently sized tuples,
228+
# then there are n₁ * n₂ * ... * nₖ possible combinations of pointer positions.
229+
# therefore, we only take this branch if all union members consume the same number of items.
230+
231+
# create copies of self for each item in the union
232+
sub_expanders = [
233+
ArgTypeExpander(context=self.context) for _ in actual_type.items
234+
]
235+
for expander in sub_expanders:
236+
expander.tuple_index = int(self.tuple_index)
237+
expander.kwargs_used = set(self.kwargs_used)
238+
239+
candidate_type = make_simplified_union(
240+
[
241+
e.expand_actual_type(
242+
item, actual_kind, formal_name, formal_kind, allow_unpack
243+
)
244+
for e, item in zip(sub_expanders, actual_type.items)
245+
]
246+
)
247+
assert all(expander == sub_expanders[0] for expander in sub_expanders)
248+
# carry over the new state if all sub-expanders are the same state
249+
self.tuple_index = int(sub_expanders[0].tuple_index)
250+
self.kwargs_used = set(sub_expanders[0].kwargs_used)
251+
return candidate_type
252+
else:
253+
# otherwise, we fall back to checking using the join of the union members.
254+
# for better results we first map all instances to Iterable[T]
255+
from mypy.subtypes import is_subtype
256+
257+
iterable_type = self.context.iterable_type
258+
259+
def as_iterable_type(t: Type) -> Type:
260+
"""Map a type to the iterable supertype if it is a subtype."""
261+
p_t = get_proper_type(t)
262+
if isinstance(p_t, Instance) and is_subtype(t, iterable_type):
263+
return map_instance_to_supertype(p_t, iterable_type.type)
264+
if isinstance(p_t, TupleType):
265+
# Convert tuple[A, B, C] to Iterable[A | B | C].
266+
return Instance(iterable_type.type, [make_simplified_union(p_t.items)])
267+
return t
268+
269+
joined_type = join_type_list([as_iterable_type(t) for t in actual_type.items])
270+
assert not isinstance(get_proper_type(joined_type), TupleType)
271+
return self.expand_actual_type(
272+
joined_type, actual_kind, formal_name, formal_kind, allow_unpack
273+
)
274+
196275
if isinstance(actual_type, TypeVarTupleType):
197276
# This code path is hit when *Ts is passed to a callable and various
198277
# special-handling didn't catch this. The best thing we can do is to use
@@ -265,3 +344,21 @@ def expand_actual_type(
265344
else:
266345
# No translation for other kinds -- 1:1 mapping.
267346
return original_actual
347+
348+
349+
def is_equal_sized_tuples(types: Sequence[Type]) -> bool:
350+
"""Check if all types are tuples of the same size."""
351+
if not types:
352+
return True
353+
354+
iterator = iter(types)
355+
first = get_proper_type(next(iterator))
356+
if not isinstance(first, TupleType):
357+
return False
358+
size = first.length()
359+
360+
for item in iterator:
361+
p_t = get_proper_type(item)
362+
if not isinstance(p_t, TupleType) or p_t.length() != size:
363+
return False
364+
return True

mypy/checkexpr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
freshen_function_type_vars,
2828
)
2929
from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
30+
from mypy.join import join_type_list
3031
from mypy.literals import literal
3132
from mypy.maptype import map_instance_to_supertype
3233
from mypy.meet import is_overlapping_types, narrow_declared_type
@@ -5227,6 +5228,11 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
52275228
ctx = None
52285229
tt = self.accept(item.expr, ctx)
52295230
tt = get_proper_type(tt)
5231+
if isinstance(tt, UnionType):
5232+
# Coercing union to join allows better inference in some
5233+
# special cases like `tuple[A, B] | tuple[C, D]`
5234+
tt = get_proper_type(join_type_list(tt.items))
5235+
52305236
if isinstance(tt, TupleType):
52315237
if find_unpack_in_list(tt.items) is not None:
52325238
if seen_unpack_in_items:

mypy/constraints.py

Lines changed: 2 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,14 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable, Sequence
6-
from typing import TYPE_CHECKING, Final, cast
6+
from typing import Final, cast
77
from typing_extensions import TypeGuard
88

99
import mypy.subtypes
1010
import mypy.typeops
11-
from mypy.argmap import ArgTypeExpander
1211
from mypy.erasetype import erase_typevars
1312
from mypy.maptype import map_instance_to_supertype
14-
from mypy.nodes import (
15-
ARG_OPT,
16-
ARG_POS,
17-
ARG_STAR,
18-
ARG_STAR2,
19-
CONTRAVARIANT,
20-
COVARIANT,
21-
ArgKind,
22-
TypeInfo,
23-
)
13+
from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, TypeInfo
2414
from mypy.types import (
2515
TUPLE_LIKE_INSTANCE_NAMES,
2616
AnyType,
@@ -63,9 +53,6 @@
6353
from mypy.types_utils import is_union_with_any
6454
from mypy.typestate import type_state
6555

66-
if TYPE_CHECKING:
67-
from mypy.infer import ArgumentInferContext
68-
6956
SUBTYPE_OF: Final = 0
7057
SUPERTYPE_OF: Final = 1
7158

@@ -107,175 +94,6 @@ def __eq__(self, other: object) -> bool:
10794
return (self.type_var, self.op, self.target) == (other.type_var, other.op, other.target)
10895

10996

110-
def infer_constraints_for_callable(
111-
callee: CallableType,
112-
arg_types: Sequence[Type | None],
113-
arg_kinds: list[ArgKind],
114-
arg_names: Sequence[str | None] | None,
115-
formal_to_actual: list[list[int]],
116-
context: ArgumentInferContext,
117-
) -> list[Constraint]:
118-
"""Infer type variable constraints for a callable and actual arguments.
119-
120-
Return a list of constraints.
121-
"""
122-
constraints: list[Constraint] = []
123-
mapper = ArgTypeExpander(context)
124-
125-
param_spec = callee.param_spec()
126-
param_spec_arg_types = []
127-
param_spec_arg_names = []
128-
param_spec_arg_kinds = []
129-
130-
incomplete_star_mapping = False
131-
for i, actuals in enumerate(formal_to_actual): # TODO: isn't this `enumerate(arg_types)`?
132-
for actual in actuals:
133-
if actual is None and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2): # type: ignore[unreachable]
134-
# We can't use arguments to infer ParamSpec constraint, if only some
135-
# are present in the current inference pass.
136-
incomplete_star_mapping = True # type: ignore[unreachable]
137-
break
138-
139-
for i, actuals in enumerate(formal_to_actual):
140-
if isinstance(callee.arg_types[i], UnpackType):
141-
unpack_type = callee.arg_types[i]
142-
assert isinstance(unpack_type, UnpackType)
143-
144-
# In this case we are binding all the actuals to *args,
145-
# and we want a constraint that the typevar tuple being unpacked
146-
# is equal to a type list of all the actuals.
147-
actual_types = []
148-
149-
unpacked_type = get_proper_type(unpack_type.type)
150-
if isinstance(unpacked_type, TypeVarTupleType):
151-
tuple_instance = unpacked_type.tuple_fallback
152-
elif isinstance(unpacked_type, TupleType):
153-
tuple_instance = unpacked_type.partial_fallback
154-
else:
155-
assert False, "mypy bug: unhandled constraint inference case"
156-
157-
for actual in actuals:
158-
actual_arg_type = arg_types[actual]
159-
if actual_arg_type is None:
160-
continue
161-
162-
expanded_actual = mapper.expand_actual_type(
163-
actual_arg_type,
164-
arg_kinds[actual],
165-
callee.arg_names[i],
166-
callee.arg_kinds[i],
167-
allow_unpack=True,
168-
)
169-
170-
if arg_kinds[actual] != ARG_STAR or isinstance(
171-
get_proper_type(actual_arg_type), TupleType
172-
):
173-
actual_types.append(expanded_actual)
174-
else:
175-
# If we are expanding an iterable inside * actual, append a homogeneous item instead
176-
actual_types.append(
177-
UnpackType(tuple_instance.copy_modified(args=[expanded_actual]))
178-
)
179-
180-
if isinstance(unpacked_type, TypeVarTupleType):
181-
constraints.append(
182-
Constraint(
183-
unpacked_type,
184-
SUPERTYPE_OF,
185-
TupleType(actual_types, unpacked_type.tuple_fallback),
186-
)
187-
)
188-
elif isinstance(unpacked_type, TupleType):
189-
# Prefixes get converted to positional args, so technically the only case we
190-
# should have here is like Tuple[Unpack[Ts], Y1, Y2, Y3]. If this turns out
191-
# not to hold we can always handle the prefixes too.
192-
inner_unpack = unpacked_type.items[0]
193-
assert isinstance(inner_unpack, UnpackType)
194-
inner_unpacked_type = get_proper_type(inner_unpack.type)
195-
suffix_len = len(unpacked_type.items) - 1
196-
if isinstance(inner_unpacked_type, TypeVarTupleType):
197-
# Variadic item can be either *Ts...
198-
constraints.append(
199-
Constraint(
200-
inner_unpacked_type,
201-
SUPERTYPE_OF,
202-
TupleType(
203-
actual_types[:-suffix_len], inner_unpacked_type.tuple_fallback
204-
),
205-
)
206-
)
207-
else:
208-
# ...or it can be a homogeneous tuple.
209-
assert (
210-
isinstance(inner_unpacked_type, Instance)
211-
and inner_unpacked_type.type.fullname == "builtins.tuple"
212-
)
213-
for at in actual_types[:-suffix_len]:
214-
constraints.extend(
215-
infer_constraints(inner_unpacked_type.args[0], at, SUPERTYPE_OF)
216-
)
217-
# Now handle the suffix (if any).
218-
if suffix_len:
219-
for tt, at in zip(unpacked_type.items[1:], actual_types[-suffix_len:]):
220-
constraints.extend(infer_constraints(tt, at, SUPERTYPE_OF))
221-
else:
222-
assert False, "mypy bug: unhandled constraint inference case"
223-
else:
224-
for actual in actuals:
225-
actual_arg_type = arg_types[actual]
226-
if actual_arg_type is None:
227-
continue
228-
229-
if param_spec and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
230-
# If actual arguments are mapped to ParamSpec type, we can't infer individual
231-
# constraints, instead store them and infer single constraint at the end.
232-
# It is impossible to map actual kind to formal kind, so use some heuristic.
233-
# This inference is used as a fallback, so relying on heuristic should be OK.
234-
if not incomplete_star_mapping:
235-
param_spec_arg_types.append(
236-
mapper.expand_actual_type(
237-
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
238-
)
239-
)
240-
actual_kind = arg_kinds[actual]
241-
param_spec_arg_kinds.append(
242-
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
243-
)
244-
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
245-
else:
246-
actual_type = mapper.expand_actual_type(
247-
actual_arg_type,
248-
arg_kinds[actual],
249-
callee.arg_names[i],
250-
callee.arg_kinds[i],
251-
)
252-
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
253-
constraints.extend(c)
254-
if (
255-
param_spec
256-
and not any(c.type_var == param_spec.id for c in constraints)
257-
and not incomplete_star_mapping
258-
):
259-
# Use ParamSpec constraint from arguments only if there are no other constraints,
260-
# since as explained above it is quite ad-hoc.
261-
constraints.append(
262-
Constraint(
263-
param_spec,
264-
SUPERTYPE_OF,
265-
Parameters(
266-
arg_types=param_spec_arg_types,
267-
arg_kinds=param_spec_arg_kinds,
268-
arg_names=param_spec_arg_names,
269-
imprecise_arg_kinds=True,
270-
),
271-
)
272-
)
273-
if any(isinstance(v, ParamSpecType) for v in callee.variables):
274-
# As a perf optimization filter imprecise constraints only when we can have them.
275-
constraints = filter_imprecise_kinds(constraints)
276-
return constraints
277-
278-
27997
def infer_constraints(
28098
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
28199
) -> list[Constraint]:

0 commit comments

Comments
 (0)