Skip to content

Commit ee80823

Browse files
committed
Add typing
1 parent ff54160 commit ee80823

File tree

5 files changed

+166
-106
lines changed

5 files changed

+166
-106
lines changed

mockito/invocation.py

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
1919
# THE SOFTWARE.
2020

21+
from __future__ import annotations
2122
from abc import ABC
2223
import os
2324
import inspect
@@ -28,7 +29,12 @@
2829
from . import verification as verificationModule
2930
from .utils import contains_strict
3031

31-
from typing import Any, Callable, Deque, Dict, Tuple
32+
from typing import TYPE_CHECKING
33+
34+
if TYPE_CHECKING:
35+
from typing import Any, Callable, NoReturn, Self, TypeVar, TYPE_CHECKING
36+
from .mocking import Mock
37+
T = TypeVar('T')
3238

3339

3440
class InvocationError(AttributeError):
@@ -45,15 +51,15 @@ class AnswerError(AttributeError):
4551

4652

4753
class Invocation(object):
48-
def __init__(self, mock, method_name):
54+
def __init__(self, mock: Mock, method_name: str) -> None:
4955
self.mock = mock
5056
self.method_name = method_name
5157
self.strict = mock.strict
5258

53-
self.params: Tuple[Any, ...] = ()
54-
self.named_params: Dict[str, Any] = {}
59+
self.params: tuple[Any, ...] = ()
60+
self.named_params: dict[str, Any] = {}
5561

56-
def _remember_params(self, params, named_params):
62+
def _remember_params(self, params: tuple, named_params: dict) -> None:
5763
self.params = params
5864
self.named_params = named_params
5965

@@ -68,27 +74,29 @@ def __repr__(self):
6874

6975

7076
class RealInvocation(Invocation, ABC):
71-
def __init__(self, mock, method_name):
77+
def __init__(self, mock: Mock, method_name: str) -> None:
7278
super(RealInvocation, self).__init__(mock, method_name)
7379
self.verified = False
7480
self.verified_inorder = False
7581

7682

7783
class RememberedInvocation(RealInvocation):
78-
def ensure_mocked_object_has_method(self, method_name):
84+
def ensure_mocked_object_has_method(self, method_name: str) -> None:
7985
if not self.mock.has_method(method_name):
8086
raise InvocationError(
8187
"You tried to call a method '%s' the object (%s) doesn't "
8288
"have." % (method_name, self.mock.mocked_obj))
8389

84-
def ensure_signature_matches(self, method_name, args, kwargs):
90+
def ensure_signature_matches(
91+
self, method_name: str, args: tuple, kwargs: dict
92+
) -> None:
8593
sig = self.mock.get_signature(method_name)
8694
if not sig:
8795
return
8896

8997
signature.match_signature(sig, args, kwargs)
9098

91-
def __call__(self, *params, **named_params):
99+
def __call__(self, *params: Any, **named_params: Any) -> Any | None:
92100
if self.mock.eat_self(self.method_name):
93101
params_without_first_arg = params[1:]
94102
else:
@@ -141,7 +149,7 @@ class RememberedProxyInvocation(RealInvocation):
141149
142150
Calls method on original object and returns it's return value.
143151
"""
144-
def __call__(self, *params, **named_params):
152+
def __call__(self, *params: Any, **named_params: Any) -> Any:
145153
self._remember_params(params, named_params)
146154
self.mock.remember(self)
147155
obj = self.mock.spec
@@ -174,7 +182,7 @@ def compare(p1, p2):
174182
return False
175183
return True
176184

177-
def capture_arguments(self, invocation):
185+
def capture_arguments(self, invocation: RealInvocation) -> None:
178186
"""Capture arguments of `invocation` into "capturing" matchers of self.
179187
180188
This is used in conjunction with "capturing" matchers like
@@ -204,7 +212,7 @@ def capture_arguments(self, invocation):
204212
p1.capture_value(p2)
205213

206214

207-
def _remember_params(self, params, named_params):
215+
def _remember_params(self, params: tuple, named_params: dict) -> None:
208216
if (
209217
contains_strict(params, Ellipsis)
210218
and (params[-1] is not Ellipsis or named_params)
@@ -231,7 +239,7 @@ def wrap(p):
231239
# Note: matches(a, b) does not imply matches(b, a) because
232240
# the left side might contain wildcards (like Ellipsis) or matchers.
233241
# In its current form the right side is a concrete call signature.
234-
def matches(self, invocation): # noqa: C901 (too complex)
242+
def matches(self, invocation: Invocation) -> bool: # noqa: C901, E501 (too complex)
235243
if self.method_name != invocation.method_name:
236244
return False
237245

@@ -294,11 +302,16 @@ class VerifiableInvocation(MatchingInvocation):
294302
call. But the `__call__` is essentially virtual and can contain
295303
placeholders and matchers.
296304
"""
297-
def __init__(self, mock, method_name, verification):
305+
def __init__(
306+
self,
307+
mock: Mock,
308+
method_name: str,
309+
verification: verificationModule.VerificationMode
310+
) -> None:
298311
super(VerifiableInvocation, self).__init__(mock, method_name)
299312
self.verification = verification
300313

301-
def __call__(self, *params, **named_params):
314+
def __call__(self, *params: Any, **named_params: Any) -> None:
302315
self._remember_params(params, named_params)
303316
matched_invocations = []
304317
for invocation in self.mock.invocations:
@@ -321,7 +334,9 @@ def __call__(self, *params, **named_params):
321334
stub.allow_zero_invocations = True
322335

323336

324-
def verification_has_lower_bound_of_zero(verification):
337+
def verification_has_lower_bound_of_zero(
338+
verification: verificationModule.VerificationMode | None
339+
) -> bool:
325340
if (
326341
isinstance(verification, verificationModule.Times)
327342
and verification.wanted_count == 0
@@ -372,7 +387,13 @@ class StubbedInvocation(MatchingInvocation):
372387
there is no "new" keyword in Python.)
373388
374389
"""
375-
def __init__(self, mock, method_name, verification=None, strict=None):
390+
def __init__(
391+
self,
392+
mock: Mock,
393+
method_name: str,
394+
verification: verificationModule.VerificationMode | None = None,
395+
strict: bool | None = None
396+
) -> None:
376397
super(StubbedInvocation, self).__init__(mock, method_name)
377398

378399
#: Holds the verification set up via `expect`.
@@ -391,26 +412,25 @@ def __init__(self, mock, method_name, verification=None, strict=None):
391412

392413
#: Set if `verifyStubbedInvocationsAreUsed` should pass, regardless
393414
#: of any factual invocation. E.g. set by `expect(..., times=0)`
394-
if verification_has_lower_bound_of_zero(verification):
395-
self.allow_zero_invocations = True
396-
else:
397-
self.allow_zero_invocations = False
398-
415+
self.allow_zero_invocations: bool = \
416+
verification_has_lower_bound_of_zero(verification)
399417

400-
def ensure_mocked_object_has_method(self, method_name):
418+
def ensure_mocked_object_has_method(self, method_name: str) -> None:
401419
if not self.mock.has_method(method_name):
402420
raise InvocationError(
403421
"You tried to stub a method '%s' the object (%s) doesn't "
404422
"have." % (method_name, self.mock.mocked_obj))
405423

406-
def ensure_signature_matches(self, method_name, args, kwargs):
424+
def ensure_signature_matches(
425+
self, method_name: str, args: tuple, kwargs: dict
426+
) -> None:
407427
sig = self.mock.get_signature(method_name)
408428
if not sig:
409429
return
410430

411431
signature.match_signature_allowing_placeholders(sig, args, kwargs)
412432

413-
def __call__(self, *params, **named_params):
433+
def __call__(self, *params: Any, **named_params: Any) -> AnswerSelector:
414434
if self.strict:
415435
self.ensure_mocked_object_has_method(self.method_name)
416436
self.ensure_signature_matches(
@@ -421,13 +441,13 @@ def __call__(self, *params, **named_params):
421441
self.mock.finish_stubbing(self)
422442
return AnswerSelector(self)
423443

424-
def forget_self(self):
444+
def forget_self(self) -> None:
425445
self.mock.forget_stubbed_invocation(self)
426446

427-
def add_answer(self, answer):
447+
def add_answer(self, answer: Callable) -> None:
428448
self.answers.add(answer)
429449

430-
def answer_first(self, *args, **kwargs):
450+
def answer_first(self, *args: Any, **kwargs: Any) -> Any:
431451
self.used += 1
432452
return self.answers.answer(*args, **kwargs)
433453

@@ -466,54 +486,52 @@ def should_answer(self, invocation: RememberedInvocation) -> None:
466486
# to get verified 'implicitly', on-the-go, so we set this flag here.
467487
invocation.verified = True
468488

469-
def verify(self):
489+
def verify(self) -> None:
470490
if self.verification:
471491
self.verification.verify(self, self.used)
472492

473-
def check_used(self):
493+
def check_used(self) -> None:
474494
if not self.allow_zero_invocations and self.used < len(self.answers):
475495
raise verificationModule.VerificationError(
476496
"\nUnused stub: %s" % self)
477497

478498

479-
def return_(value):
480-
def answer(*args, **kwargs):
499+
def return_(value: T) -> Callable[..., T]:
500+
def answer(*args, **kwargs) -> T:
481501
return value
482502
return answer
483503

484-
def raise_(exception):
485-
def answer(*args, **kwargs):
504+
def raise_(exception: Exception | type[Exception]) -> Callable[..., NoReturn]:
505+
def answer(*args, **kwargs) -> NoReturn:
486506
raise exception
487507
return answer
488508

489-
490-
def discard_self(function):
491-
def function_without_self(*args, **kwargs):
509+
def discard_self(function: Callable[..., T]) -> Callable[..., T]:
510+
def function_without_self(*args, **kwargs) -> T:
492511
args = args[1:]
493512
return function(*args, **kwargs)
494-
495513
return function_without_self
496514

497515

498516
class AnswerSelector(object):
499-
def __init__(self, invocation):
517+
def __init__(self, invocation: StubbedInvocation) -> None:
500518
self.invocation = invocation
501519
self.discard_first_arg = \
502520
invocation.mock.eat_self(invocation.method_name)
503521

504-
def thenReturn(self, *return_values):
522+
def thenReturn(self, *return_values: Any) -> Self:
505523
for return_value in return_values or (None,):
506524
answer = return_(return_value)
507525
self.__then(answer)
508526
return self
509527

510-
def thenRaise(self, *exceptions):
528+
def thenRaise(self, *exceptions: Exception | type[Exception]) -> Self:
511529
for exception in exceptions or (Exception,):
512530
answer = raise_(exception)
513531
self.__then(answer)
514532
return self
515533

516-
def thenAnswer(self, *callables):
534+
def thenAnswer(self, *callables: Callable) -> Self:
517535
if not callables:
518536
raise TypeError("No answer function provided")
519537
for callable in callables:
@@ -523,7 +541,7 @@ def thenAnswer(self, *callables):
523541
self.__then(answer)
524542
return self
525543

526-
def thenCallOriginalImplementation(self):
544+
def thenCallOriginalImplementation(self) -> Self:
527545
answer = self.invocation.mock.get_original_method(
528546
self.invocation.method_name
529547
)
@@ -545,36 +563,36 @@ def thenCallOriginalImplementation(self):
545563
self.__then(answer)
546564
return self
547565

548-
def __then(self, answer):
566+
def __then(self, answer: Callable) -> None:
549567
self.invocation.add_answer(answer)
550568

551-
def __enter__(self):
569+
def __enter__(self) -> None:
552570
pass
553571

554-
def __exit__(self, *exc_info):
572+
def __exit__(self, *exc_info) -> None:
555573
self.invocation.verify()
556574
if os.environ.get("MOCKITO_CONTEXT_MANAGERS_CHECK_USAGE", "1") == "1":
557575
self.invocation.check_used()
558576
self.invocation.forget_self()
559577

560578

561579
class CompositeAnswer(object):
562-
def __init__(self):
580+
def __init__(self) -> None:
563581
#: Container for answers, which are just ordinary callables
564-
self.answers: Deque[Callable] = deque()
582+
self.answers: deque[Callable] = deque()
565583

566584
#: Counter for the maximum answers we ever had
567585
self.answer_count = 0
568586

569-
def __len__(self):
587+
def __len__(self) -> int:
570588
# The minimum is '1' bc we always have a default answer of 'None'
571589
return max(1, self.answer_count)
572590

573-
def add(self, answer):
591+
def add(self, answer: Callable) -> None:
574592
self.answer_count += 1
575593
self.answers.append(answer)
576594

577-
def answer(self, *args, **kwargs):
595+
def answer(self, *args: Any, **kwargs: Any) -> Any:
578596
if len(self.answers) == 0:
579597
return None
580598

mockito/mock_registry.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
1919
# THE SOFTWARE.
2020

21+
from __future__ import annotations
22+
from typing import TYPE_CHECKING
23+
24+
if TYPE_CHECKING:
25+
from .mocking import Mock
2126

2227

2328
class MockRegistry:
@@ -30,26 +35,26 @@ class MockRegistry:
3035
def __init__(self):
3136
self.mocks = IdentityMap()
3237

33-
def register(self, obj, mock):
38+
def register(self, obj: object, mock: Mock) -> None:
3439
self.mocks[obj] = mock
3540

36-
def mock_for(self, obj):
41+
def mock_for(self, obj: object) -> Mock | None:
3742
return self.mocks.get(obj, None)
3843

39-
def unstub(self, obj):
44+
def unstub(self, obj: object) -> None:
4045
try:
4146
mock = self.mocks.pop(obj)
4247
except KeyError:
4348
pass
4449
else:
4550
mock.unstub()
4651

47-
def unstub_all(self):
52+
def unstub_all(self) -> None:
4853
for mock in self.get_registered_mocks():
4954
mock.unstub()
5055
self.mocks.clear()
5156

52-
def get_registered_mocks(self):
57+
def get_registered_mocks(self) -> list[Mock]:
5358
return self.mocks.values()
5459

5560

0 commit comments

Comments
 (0)