diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a8afebbd9923..d6982a6bed43 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5069,31 +5069,56 @@ def fast_container_type( module-level constant definitions. Limitations: + - no active type context + - at least one item - no star expressions - - the joined type of all entries must be an Instance or Tuple type + - not after deferral + - either exactly one distinct type inside, + or the joined type of all entries is an Instance or Tuple type, """ ctx = self.type_context[-1] - if ctx: + if ctx or not e.items: + return None + if self.chk.current_node_deferred: + # Guarantees that all items will be Any, we'll reject it anyway. return None rt = self.resolved_type.get(e, None) if rt is not None: return rt if isinstance(rt, Instance) else None values: list[Type] = [] + # Preserve join order while avoiding O(n) lookups at every iteration + values_set: set[Type] = set() for item in e.items: if isinstance(item, StarExpr): # fallback to slow path self.resolved_type[e] = NoneType() return None - values.append(self.accept(item)) - vt = join.join_type_list(values) - if not allow_fast_container_literal(vt): + + typ = self.accept(item) + if typ not in values_set: + values.append(typ) + values_set.add(typ) + + vt = self._first_or_join_fast_item(values) + if vt is None: self.resolved_type[e] = NoneType() return None ct = self.chk.named_generic_type(container_fullname, [vt]) self.resolved_type[e] = ct return ct + def _first_or_join_fast_item(self, items: list[Type]) -> Type | None: + if len(items) == 1 and not self.chk.current_node_deferred: + return items[0] + typ = join.join_type_list(items) + if not allow_fast_container_literal(typ): + # TODO: This is overly strict, many other types can be joined safely here. + # However, our join implementation isn't bug-free, and some joins may produce + # undesired `Any`s or even more surprising results. + return None + return typ + def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: str) -> Type: # fast path t = self.fast_container_type(e, fullname) @@ -5254,18 +5279,30 @@ def fast_dict_type(self, e: DictExpr) -> Type | None: module-level constant definitions. Limitations: + - no active type context + - at least one item - only supported star expressions are other dict instances - - the joined types of all keys and values must be Instance or Tuple types + - either exactly one distinct type (keys and values separately) inside, + or the joined type of all entries is an Instance or Tuple type """ ctx = self.type_context[-1] - if ctx: + if ctx or not e.items: return None + + if self.chk.current_node_deferred: + # Guarantees that all items will be Any, we'll reject it anyway. + return None + rt = self.resolved_type.get(e, None) if rt is not None: return rt if isinstance(rt, Instance) else None + keys: list[Type] = [] values: list[Type] = [] + # Preserve join order while avoiding O(n) lookups at every iteration + keys_set: set[Type] = set() + values_set: set[Type] = set() stargs: tuple[Type, Type] | None = None for key, value in e.items: if key is None: @@ -5280,13 +5317,25 @@ def fast_dict_type(self, e: DictExpr) -> Type | None: self.resolved_type[e] = NoneType() return None else: - keys.append(self.accept(key)) - values.append(self.accept(value)) - kt = join.join_type_list(keys) - vt = join.join_type_list(values) - if not (allow_fast_container_literal(kt) and allow_fast_container_literal(vt)): + key_t = self.accept(key) + if key_t not in keys_set: + keys.append(key_t) + keys_set.add(key_t) + value_t = self.accept(value) + if value_t not in values_set: + values.append(value_t) + values_set.add(value_t) + + kt = self._first_or_join_fast_item(keys) + if kt is None: self.resolved_type[e] = NoneType() return None + + vt = self._first_or_join_fast_item(values) + if vt is None: + self.resolved_type[e] = NoneType() + return None + if stargs and (stargs[0] != kt or stargs[1] != vt): self.resolved_type[e] = NoneType() return None diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 78680684f69b..abeb5face26f 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`7) -> builtins.list[S`7]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`9) -> builtins.list[S`9]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`2) -> builtins.list[S`2]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`4) -> builtins.list[S`4]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] @@ -3118,11 +3118,11 @@ def dec4_bound(f: Callable[[I], List[T]]) -> Callable[[I], T]: reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]" reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`5) -> builtins.list[S`5]" reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`8) -> S`8" -reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`12) -> S`12" +reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`11) -> S`11" reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" -reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`20) -> builtins.list[S`20]" -reveal_type(dec4(lambda x: x)) # N: Revealed type is "def [T] (builtins.list[T`24]) -> T`24" +reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`19) -> builtins.list[S`19]" +reveal_type(dec4(lambda x: x)) # N: Revealed type is "def [T] (builtins.list[T`23]) -> T`23" dec4_bound(lambda x: x) # E: Value of type variable "I" of "dec4_bound" cannot be "list[T]" [builtins fixtures/list.pyi] diff --git a/test-data/unit/check-redefine2.test b/test-data/unit/check-redefine2.test index 3523772611aa..1abe957240b5 100644 --- a/test-data/unit/check-redefine2.test +++ b/test-data/unit/check-redefine2.test @@ -1073,7 +1073,7 @@ def f() -> None: while int(): x = [x] - reveal_type(x) # N: Revealed type is "Union[Any, builtins.list[Any], builtins.list[Union[Any, builtins.list[Any]]], builtins.list[Union[Any, builtins.list[Any], builtins.list[Union[Any, builtins.list[Any]]]]], builtins.list[Union[Any, builtins.list[Any], builtins.list[Union[Any, builtins.list[Any]]], builtins.list[Union[Any, builtins.list[Any], builtins.list[Union[Any, builtins.list[Any]]]]]]]]" + reveal_type(x) # N: Revealed type is "Union[Any, builtins.list[Any]]" [case testNewRedefinePartialNoneEmptyList] # flags: --allow-redefinition-new --local-partial-types diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index 88ca53c8ed66..05c34eb70796 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -2018,7 +2018,7 @@ class Ben(Object): } @classmethod def doit(cls) -> Foo: - reveal_type(cls.MY_MAP) # N: Revealed type is "builtins.dict[builtins.str, def [Self <: __main__.Foo] (self: Self`4) -> Self`4]" + reveal_type(cls.MY_MAP) # N: Revealed type is "builtins.dict[builtins.str, def [Self <: __main__.Foo] (self: Self`1) -> Self`1]" foo_method = cls.MY_MAP["foo"] return foo_method(Foo()) [builtins fixtures/isinstancelist.pyi]