Skip to content

Commit 090363a

Browse files
authored
Adds strict delegate types (#297)
* Adds strict delegate types * Fixes CI * Update supports.rst * Update supports.rst * Update supports.rst * Update supports.rst
1 parent 892870c commit 090363a

25 files changed

+351
-498
lines changed

classes/_registry.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,48 @@
1-
from typing import Callable, Dict, NoReturn, Optional
1+
from typing import Callable, Dict, NoReturn, Optional, Tuple
2+
3+
from typing_extensions import Final
24

35
TypeRegistry = Dict[type, Callable]
46

7+
#: We use this to exclude `None` as a default value for `exact_type`.
8+
DefaultValue: Final = type('DefaultValueType', (object,), {})
9+
10+
#: Used both in runtime and during mypy type-checking.
11+
INVALID_ARGUMENTS_MSG: Final = (
12+
'Only a single argument can be applied to `.instance`'
13+
)
14+
515

616
def choose_registry( # noqa: WPS211
717
# It has multiple arguments, but I don't see an easy and performant way
818
# to refactor it: I don't want to create extra structures
919
# and I don't want to create a class with methods.
10-
typ: type,
11-
is_protocol: bool,
12-
delegate: Optional[type],
20+
exact_type: Optional[type],
21+
protocol: type,
22+
delegate: type,
1323
delegates: TypeRegistry,
14-
instances: TypeRegistry,
24+
exact_types: TypeRegistry,
1525
protocols: TypeRegistry,
16-
) -> TypeRegistry:
26+
) -> Tuple[TypeRegistry, type]:
1727
"""
1828
Returns the appropriate registry to store the passed type.
1929
2030
It depends on how ``instance`` method is used and also on the type itself.
2131
"""
22-
if is_protocol and delegate is not None:
23-
raise ValueError('Both `is_protocol` and `delegate` are passed')
32+
passed_args = list(filter(
33+
_is_not_default_argument_value,
34+
(exact_type, protocol, delegate),
35+
))
36+
if not passed_args:
37+
raise ValueError('At least one argument to `.instance` is required')
38+
if len(passed_args) > 1:
39+
raise ValueError(INVALID_ARGUMENTS_MSG)
2440

25-
if is_protocol:
26-
return protocols
27-
elif delegate is not None:
28-
return delegates
29-
return instances
41+
if _is_not_default_argument_value(delegate):
42+
return delegates, delegate
43+
elif _is_not_default_argument_value(protocol):
44+
return protocols, protocol
45+
return exact_types, exact_type if exact_type is not None else type(None)
3046

3147

3248
def default_implementation(instance, *args, **kwargs) -> NoReturn:
@@ -36,3 +52,7 @@ def default_implementation(instance, *args, **kwargs) -> NoReturn:
3652
type(instance).__qualname__,
3753
),
3854
)
55+
56+
57+
def _is_not_default_argument_value(arg: Optional[type]) -> bool:
58+
return arg is not DefaultValue

classes/_typeclass.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@
7272
We also support protocols. It has the same limitation as ``Generic`` types.
7373
It is also dispatched after all regular instances are checked.
7474
75-
To work with protocols, one needs to pass ``is_protocol`` flag to instance:
75+
To work with protocols, one needs
76+
to pass ``protocol`` named argument to instance:
7677
7778
.. code:: python
7879
7980
>>> from typing import Sequence
8081
81-
>>> @example.instance(Sequence, is_protocol=True)
82+
>>> @example.instance(protocol=Sequence)
8283
... def _sequence_example(instance: Sequence) -> str:
8384
... return ','.join(str(item) for item in instance)
8485
@@ -99,7 +100,7 @@
99100
>>> class CustomProtocol(Protocol):
100101
... field: str
101102
102-
>>> @example.instance(CustomProtocol, is_protocol=True)
103+
>>> @example.instance(protocol=CustomProtocol)
103104
... def _custom_protocol_example(instance: CustomProtocol) -> str:
104105
... return instance.field
105106
@@ -131,6 +132,7 @@
131132
from typing_extensions import TypeGuard, final
132133

133134
from classes._registry import (
135+
DefaultValue,
134136
TypeRegistry,
135137
choose_registry,
136138
default_implementation,
@@ -313,7 +315,7 @@ class _TypeClass( # noqa: WPS214
313315

314316
# Registry:
315317
'_delegates',
316-
'_instances',
318+
'_exact_types',
317319
'_protocols',
318320

319321
# Cache:
@@ -362,7 +364,7 @@ def __init__(
362364

363365
# Registries:
364366
self._delegates: TypeRegistry = {}
365-
self._instances: TypeRegistry = {}
367+
self._exact_types: TypeRegistry = {}
366368
self._protocols: TypeRegistry = {}
367369

368370
# Cache parts:
@@ -382,8 +384,9 @@ def __call__(
382384
383385
The resolution order is the following:
384386
385-
1. Exact types that are passed as ``.instance`` arguments
386-
2. Protocols that are passed with ``is_protocol=True``
387+
1. Delegates passed with ``delegate=``
388+
2. Exact types that are passed as ``.instance`` arguments
389+
3. Protocols that are passed with ``protocol=``
387390
388391
We don't guarantee the order of types inside groups.
389392
Use correct types, do not rely on our order.
@@ -480,7 +483,7 @@ def supports(
480483
481484
>>> from typing import Sized
482485
483-
>>> @example.instance(Sized, is_protocol=True)
486+
>>> @example.instance(protocol=Sized)
484487
... def _example_sized(instance: Sized) -> str:
485488
... return 'Size is {0}'.format(len(instance))
486489
@@ -523,18 +526,16 @@ def supports(
523526

524527
def instance(
525528
self,
526-
type_argument: Optional[_NewInstanceType],
529+
exact_type: Optional[_NewInstanceType] = DefaultValue, # type: ignore
527530
*,
528-
# TODO: at one point I would like to remove `is_protocol`
529-
# and make this function decide whether this type is protocol or not.
530-
is_protocol: bool = False,
531-
delegate: Optional[type] = None,
531+
protocol: type = DefaultValue,
532+
delegate: type = DefaultValue,
532533
) -> '_TypeClassInstanceDef[_NewInstanceType, _TypeClassType]':
533534
"""
534535
We use this method to store implementation for each specific type.
535536
536537
Args:
537-
is_protocol: required when passing protocols.
538+
protocol: required when passing protocols.
538539
delegate: required when using delegate types, for example,
539540
when working with concrete generics like ``List[str]``.
540541
@@ -543,7 +544,8 @@ def instance(
543544
544545
.. note::
545546
546-
``is_protocol`` and ``delegate`` are mutually exclusive.
547+
``exact_type``, ``protocol``, and ``delegate``
548+
are mutually exclusive. Only one argument can be passed.
547549
548550
We don't use ``@overload`` decorator here
549551
(which makes our ``mypy`` plugin even more complex)
@@ -558,7 +560,14 @@ def instance(
558560
# Then, we have a regular `type_argument`. It is used for most types.
559561
# Lastly, we have `type(None)` to handle cases
560562
# when we want to register `None` as a type / singleton value.
561-
typ = delegate or type_argument or type(None)
563+
registry, typ = choose_registry(
564+
exact_type=exact_type,
565+
protocol=protocol,
566+
delegate=delegate,
567+
exact_types=self._exact_types,
568+
protocols=self._protocols,
569+
delegates=self._delegates,
570+
)
562571

563572
# That's how we check for generics,
564573
# generics that look like `List[int]` or `set[T]` will fail this check,
@@ -567,19 +576,9 @@ def instance(
567576
isinstance(object(), typ)
568577

569578
def decorator(implementation):
570-
container = choose_registry(
571-
typ=typ,
572-
is_protocol=is_protocol,
573-
delegate=delegate,
574-
delegates=self._delegates,
575-
instances=self._instances,
576-
protocols=self._protocols,
577-
)
578-
container[typ] = implementation
579-
579+
registry[typ] = implementation
580580
self._dispatch_cache.clear()
581581
return implementation
582-
583582
return decorator
584583

585584
def _dispatch(self, instance, instance_type: type) -> Optional[Callable]:
@@ -591,15 +590,15 @@ def _dispatch(self, instance, instance_type: type) -> Optional[Callable]:
591590
2. By matching protocols
592591
3. By its ``mro``
593592
"""
594-
implementation = self._instances.get(instance_type, None)
593+
implementation = self._exact_types.get(instance_type, None)
595594
if implementation is not None:
596595
return implementation
597596

598597
for protocol, callback in self._protocols.items():
599598
if isinstance(instance, protocol):
600599
return callback
601600

602-
return _find_impl(instance_type, self._instances)
601+
return _find_impl(instance_type, self._exact_types)
603602

604603
def _dispatch_delegate(self, instance) -> Optional[Callable]:
605604
for delegate, callback in self._delegates.items():

classes/contrib/mypy/features/typeclass.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from classes.contrib.mypy.typeops import (
1818
call_signatures,
1919
fallback,
20-
instance_args,
20+
instance_type_args,
2121
mro,
2222
type_loader,
2323
)
@@ -78,7 +78,7 @@ def __call__(self, ctx: FunctionContext) -> MypyType:
7878
assert isinstance(defn, CallableType)
7979
assert defn.definition
8080

81-
instance_args.mutate_typeclass_def(
81+
instance_type_args.mutate_typeclass_def(
8282
typeclass=ctx.default_return_type,
8383
definition_fullname=defn.definition.fullname,
8484
ctx=ctx,
@@ -125,7 +125,7 @@ def __call__(self, ctx: MethodContext) -> MypyType:
125125
assert isinstance(ctx.default_return_type, Instance)
126126
assert isinstance(ctx.context, Decorator)
127127

128-
instance_args.mutate_typeclass_def(
128+
instance_type_args.mutate_typeclass_def(
129129
typeclass=ctx.default_return_type,
130130
definition_fullname=ctx.context.func.fullname,
131131
ctx=ctx,
@@ -135,6 +135,7 @@ def __call__(self, ctx: MethodContext) -> MypyType:
135135
typeclass=ctx.default_return_type,
136136
ctx=ctx,
137137
)
138+
138139
if isinstance(ctx.default_return_type.args[2], Instance):
139140
validate_associated_type.check_type(
140141
associated_type=ctx.default_return_type.args[2],
@@ -161,7 +162,7 @@ def instance_return_type(ctx: MethodContext) -> MypyType:
161162
else:
162163
passed_types.append(UninhabitedType())
163164

164-
instance_args.mutate_typeclass_instance_def(
165+
instance_type_args.mutate_typeclass_instance_def(
165166
ctx.default_return_type,
166167
ctx=ctx,
167168
typeclass=ctx.type,
@@ -237,12 +238,18 @@ def _load_typeclass(
237238
return typeclass, typeclass_ref.args[3].value
238239

239240
def _run_validation(self, instance_context: InstanceContext) -> bool:
241+
# When delegate is passed, we use it instead of instance type.
242+
# Why? Because `delegate` can repre
243+
instance_or_delegate = (
244+
instance_context.inferred_args.delegate
245+
if instance_context.inferred_args.delegate is not None
246+
else instance_context.instance_type
247+
)
240248
# We need to add `Supports` metadata before typechecking,
241249
# because it will affect type hierarchies.
242250
metadata = mro.MetadataInjector(
243251
associated_type=instance_context.associated_type,
244-
instance_type=instance_context.instance_type,
245-
delegate=instance_context.delegate,
252+
instance_type=instance_or_delegate,
246253
ctx=instance_context.ctx,
247254
)
248255
metadata.add_supports_metadata()
@@ -262,7 +269,7 @@ def _add_new_instance_type(
262269
ctx: MethodContext,
263270
) -> None:
264271
typeclass.args = (
265-
instance_args.add_unique(new_type, typeclass.args[0]),
272+
instance_type_args.add_unique(new_type, typeclass.args[0]),
266273
*typeclass.args[1:],
267274
)
268275

0 commit comments

Comments
 (0)