Skip to content

narrow to literal value #852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Basedmypy Changelog

## [Unreleased]
### Added
- `float` and `complex` literals
- narrow types to literal values
- infer literal in generics

## [2.9.0]
### Added
Expand Down
4 changes: 2 additions & 2 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import DefaultDict, Generator, Iterator, List, NamedTuple, Optional, Tuple, Union
from typing_extensions import TypeAlias as _TypeAlias

from mypy.erasetype import remove_instance_last_known_values
from mypy.join import join_simple
from mypy.literals import Key, literal, literal_hash, subkeys
from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var
Expand Down Expand Up @@ -331,7 +330,8 @@ def assign_type(
) -> None:
# We should erase last known value in binder, because if we are using it,
# it means that the target is not final, and therefore can't hold a literal.
type = remove_instance_last_known_values(type)
# HUUHHH?????
# type = remove_instance_last_known_values(type)

if self.type_assignments is not None:
# We are in a multiassign from union, defer the actual binding,
Expand Down
19 changes: 16 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3589,19 +3589,30 @@ def check_assignment(
):
lvalue.node.type = remove_instance_last_known_values(lvalue_type)

elif lvalue.node and lvalue.node.is_inferred and rvalue_type:
# for literal values
# Don't use type binder for definitions of special forms, like named tuples.
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)

elif index_lvalue:
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
type_context = self.get_variable_type_context(inferred)
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
original_rvalue_type = rvalue_type
if not (
inferred.is_final
or inferred.is_index_var
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
):
rvalue_type = remove_instance_last_known_values(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
if self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue) or self.binder.type_assignments:
# we don't always want to assign the type here as it might be something like partial
self.binder.assign_type(
lvalue, original_rvalue_type, original_rvalue_type, False
)
self.check_assignment_to_slots(lvalue)

# (type, operator) tuples for augmented assignments supported with partial types
Expand Down Expand Up @@ -4553,12 +4564,13 @@ def is_definition(self, s: Lvalue) -> bool:

def infer_variable_type(
self, name: Var, lvalue: Lvalue, init_type: Type, context: Context
) -> None:
) -> bool:
"""Infer the type of initialized variables from initializer type."""
valid = True
if isinstance(init_type, DeletedType):
self.msg.deleted_as_rvalue(init_type, context)
elif (
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
not (valid := is_valid_inferred_type(init_type, is_lvalue_final=name.is_final))
and not self.no_partial_types
):
# We cannot use the type of the initialization expression for full type
Expand All @@ -4585,6 +4597,7 @@ def infer_variable_type(
init_type = strip_type(init_type)

self.set_inferred_type(name, lvalue, init_type)
return valid

def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool:
init_type = get_proper_type(init_type)
Expand Down
24 changes: 22 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@
# Type of callback user for checking individual function arguments. See
# check_args() below for details.
ArgChecker: _TypeAlias = Callable[
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context], None
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context, bool],
None,
]

# Maximum nesting level for math union in overloads, setting this to large values
Expand Down Expand Up @@ -2175,6 +2176,13 @@ def infer_function_type_arguments(
Return a derived callable type that has the arguments applied.
"""
if self.chk.in_checked_function():
if isinstance(callee_type.ret_type, TypeVarType):
# if the return type is constant, infer as literal
rvalue_type = [
remove_instance_last_known_values(arg) if isinstance(arg, Instance) else arg
for arg in args
]

# Disable type errors during type inference. There may be errors
# due to partial available context information at this time, but
# these errors can be safely ignored as the arguments will be
Expand Down Expand Up @@ -2581,6 +2589,8 @@ def check_argument_types(
context: Context,
check_arg: ArgChecker | None = None,
object_type: Type | None = None,
*,
type_function=False,
) -> None:
"""Check argument types against a callable type.

Expand Down Expand Up @@ -2712,6 +2722,7 @@ def check_argument_types(
object_type,
args[actual],
context,
type_function,
)

def check_arg(
Expand All @@ -2726,12 +2737,16 @@ def check_arg(
object_type: Type | None,
context: Context,
outer_context: Context,
type_function=False,
) -> None:
"""Check the type of a single argument in a call."""
caller_type = get_proper_type(caller_type)
original_caller_type = get_proper_type(original_caller_type)
callee_type = get_proper_type(callee_type)

if type_function:
# TODO: make this work at all
if not isinstance(caller_type, Instance) or not caller_type.last_known_value:
caller_type = self.named_type("builtins.object")
if isinstance(caller_type, DeletedType):
self.msg.deleted_as_rvalue(caller_type, context)
# Only non-abstract non-protocol class can be given where Type[...] is expected...
Expand Down Expand Up @@ -3348,6 +3363,7 @@ def check_arg(
object_type: Type | None,
context: Context,
outer_context: Context,
type_function: bool,
) -> None:
if not arg_approximate_similarity(caller_type, callee_type):
# No match -- exit early since none of the remaining work can change
Expand Down Expand Up @@ -3580,10 +3596,14 @@ def visit_bytes_expr(self, e: BytesExpr) -> Type:

def visit_float_expr(self, e: FloatExpr) -> Type:
"""Type check a float literal (trivial)."""
if mypy.options._based:
return self.infer_literal_expr_type(e.value, "builtins.float")
return self.named_type("builtins.float")

def visit_complex_expr(self, e: ComplexExpr) -> Type:
"""Type check a complex literal."""
if mypy.options._based:
return self.infer_literal_expr_type(e.value, "builtins.complex")
return self.named_type("builtins.complex")

def visit_ellipsis(self, e: EllipsisExpr) -> Type:
Expand Down
43 changes: 29 additions & 14 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_STAR, FakeInfo, Var
Expand Down Expand Up @@ -185,6 +186,16 @@ def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
super().__init__()
self.variables = variables
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
self._erase_literals = False

@contextmanager
def erase_literals(self):
_erase_literals = self._erase_literals
self._erase_literals = True
try:
yield
finally:
self._erase_literals = _erase_literals

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand All @@ -211,7 +222,8 @@ def visit_erased_type(self, t: ErasedType) -> Type:
return t

def visit_instance(self, t: Instance) -> Type:
args = self.expand_types_with_unpack(list(t.args))
with self.erase_literals():
args = self.expand_types_with_unpack(list(t.args))

if isinstance(t.type, FakeInfo):
# The type checker expands function definitions and bodies
Expand All @@ -238,7 +250,7 @@ def visit_type_var(self, t: TypeVarType) -> Type:
if t.id.is_self():
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
repl = self.variables.get(t.id, t)
if isinstance(repl, ProperType) and isinstance(repl, Instance):
if self._erase_literals and isinstance(repl, ProperType) and isinstance(repl, Instance):
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
return repl.copy_modified(last_known_value=None)
Expand Down Expand Up @@ -410,17 +422,18 @@ def visit_callable_type(self, t: CallableType) -> CallableType:

var_arg = t.var_arg()
needs_normalization = False
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
needs_normalization = True
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_types = self.expand_types(t.arg_types)
expanded = t.copy_modified(
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
)
with self.erase_literals():
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
needs_normalization = True
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_types = self.expand_types(t.arg_types)
expanded = t.copy_modified(
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
)
if needs_normalization:
return expanded.with_normalized_var_args()
return expanded
Expand Down Expand Up @@ -467,7 +480,9 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
return cached
fallback = t.fallback.accept(self)
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
with self.erase_literals():
# TODO: we don't want to erase literals for `ReadOnly` keys
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
self.set_cached(t, result)
return result

Expand Down
3 changes: 2 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2129,7 +2129,8 @@ def numeric_type(self, value: object, n: AST) -> Type:
# Other kinds of numbers (floats, complex) are not valid parameters for
# RawExpressionType so we just pass in 'None' for now. We'll report the
# appropriate error at a later stage.
numeric_value = None
# based: they are valid
numeric_value = value
type_name = f"builtins.{type(value).__name__}"
return RawExpressionType(
numeric_value, type_name, line=self.line, column=getattr(n, "col_offset", -1)
Expand Down
2 changes: 1 addition & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
if declaration is None or is_subtype(value, declaration):
return value

return declaration
return value


def trivial_join(s: Type, t: Type) -> Type:
Expand Down
5 changes: 4 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,12 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
# Special case: 'int' can't be narrowed down to a native int type such as
# i64, since they have different runtime representations.
return original_declared
if isinstance(narrowed, Instance) and narrowed.last_known_value:
return narrowed
return meet_types(original_declared, original_narrowed)
elif isinstance(declared, (TupleType, TypeType, LiteralType)):
return meet_types(original_declared, original_narrowed)
# this way around to preserve the last know of the items in the tuple
return meet_types(original_narrowed, original_declared, intersect=True)
elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance):
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
if narrowed.type.fullname == "builtins.dict" and all(
Expand Down
6 changes: 6 additions & 0 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,14 @@ def test_simplified_intersection(self):
[fx.b, IntersectionType([fx.c, IntersectionType([fx.d])])],
IntersectionType([fx.b, fx.c, fx.d]),
)
self.assert_simplified_intersection([fx.bool_type, fx.lit_true], fx.lit_true)

# special case: it's not currently symmetric when there are last known values
narrowed = fx.bool_type.copy_modified(last_known_value=fx.lit_true)
assert_equal(make_simplified_intersection([narrowed, fx.bool_type]), narrowed)

def assert_simplified_intersection(self, original: list[Type], intersection: Type) -> None:
__tracebackhide__ = True
assert_equal(make_simplified_intersection(original), intersection)
assert_equal(make_simplified_intersection(list(reversed(original))), intersection)

Expand Down
14 changes: 11 additions & 3 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,6 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
return res
elif isinstance(node, TypeInfo):
return self.analyze_type_with_type_info(node, t.args, t, t.empty_tuple_index)

elif node.fullname in TYPE_ALIAS_NAMES:
return AnyType(TypeOfAny.special_form)
# Concatenate is an operator, no need for a proper type
Expand Down Expand Up @@ -1432,7 +1431,11 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type:

if self.report_invalid_types:
msg = None
if t.base_type_name in ("builtins.int", "builtins.bool"):
if (
t.base_type_name in ("builtins.int", "builtins.bool")
or mypy.options._based
and t.base_type_name in ("builtins.float", "builtins.complex")
):
if not self.options.bare_literals:
# The only time it makes sense to use an int or bool is inside of
# a literal type.
Expand All @@ -1453,7 +1456,12 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type:
self.fail(msg, t, code=codes.VALID_TYPE)
if t.note is not None:
self.note(t.note, t, code=codes.VALID_TYPE)
if t.base_type_name in ("builtins.int", "builtins.bool"):
if t.base_type_name in (
"builtins.int",
"builtins.bool",
"builtins.float",
"builtins.complex",
):
v = t.literal_value
assert v is not None
result = LiteralType(
Expand Down
2 changes: 2 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,8 @@ def _remove_redundant_intersection_items(items: list[Type], keep_erased: bool) -
if inner_i in removed:
continue
proper_inner = get_proper_type(items[inner_i])
# hacky: we check this one first, because it's more likely that the value on the left
# has a last known value/metadata/extra args
if is_proper_subtype(
proper_outer, proper_inner, keep_erased_types=keep_erased, ignore_promotions=True
):
Expand Down
10 changes: 7 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
#
# Note: Float values are only used internally. They are not accepted within
# Literal[...].
LiteralValue: _TypeAlias = Union[int, str, bool, float]
LiteralValue: _TypeAlias = Union[int, str, bool, float, complex]


# If we only import type_visitor in the middle of the file, mypy
Expand Down Expand Up @@ -3137,14 +3137,18 @@ def value_repr(self) -> str:
def serialize(self) -> JsonDict | str:
return {
".class": "LiteralType",
"value": self.value,
"value": self.value if not isinstance(self.value, complex) else str(self.value),
"fallback": self.fallback.serialize(),
}

@classmethod
def deserialize(cls, data: JsonDict) -> LiteralType:
assert data[".class"] == "LiteralType"
return LiteralType(value=data["value"], fallback=Instance.deserialize(data["fallback"]))
fallback = Instance.deserialize(data["fallback"])
value = data["value"]
if fallback.type_ref == "builtins.complex":
value = complex(value)
return LiteralType(value=value, fallback=fallback)

def is_singleton_type(self) -> bool:
return self.is_enum_literal() or isinstance(self.value, bool)
Expand Down
12 changes: 6 additions & 6 deletions mypy/typeshed/stdlib/operator.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,16 @@ if sys.version_info >= (3, 11):
@final
class attrgetter(Generic[_T_co]):
@overload
def __new__(cls, attr: str, /) -> attrgetter[Any]: ...
def __new__(cls, attr: str, /) -> attrgetter[object]: ...
@overload
def __new__(cls, attr: str, attr2: str, /) -> attrgetter[tuple[Any, Any]]: ...
def __new__(cls, attr: str, attr2: str, /) -> attrgetter[tuple[object, object]]: ...
@overload
def __new__(cls, attr: str, attr2: str, attr3: str, /) -> attrgetter[tuple[Any, Any, Any]]: ...
def __new__(cls, attr: str, attr2: str, attr3: str, /) -> attrgetter[tuple[object, object, object]]: ...
@overload
def __new__(cls, attr: str, attr2: str, attr3: str, attr4: str, /) -> attrgetter[tuple[Any, Any, Any, Any]]: ...
def __new__(cls, attr: str, attr2: str, attr3: str, attr4: str, /) -> attrgetter[tuple[object, object, object, object]]: ...
@overload
def __new__(cls, attr: str, /, *attrs: str) -> attrgetter[tuple[Any, ...]]: ...
def __call__(self, obj: Any, /) -> _T_co: ...
def __new__(cls, attr: str, /, *attrs: str) -> attrgetter[tuple[object, ...]]: ...
def __call__(self, obj: object, /) -> _T_co: ...

@final
class itemgetter(Generic[_T_co]):
Expand Down
Loading
Loading