18
18
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
19
# THE SOFTWARE.
20
20
21
+ from __future__ import annotations
21
22
from abc import ABC
22
23
import os
23
24
import inspect
28
29
from . import verification as verificationModule
29
30
from .utils import contains_strict
30
31
31
- from typing import Any , Callable , Deque , Dict , Tuple
32
+ from typing import TYPE_CHECKING
33
+
34
+ if TYPE_CHECKING :
35
+ from typing import Any , Callable , NoReturn , Self , TypeVar , TYPE_CHECKING
36
+ from .mocking import Mock
37
+ T = TypeVar ('T' )
32
38
33
39
34
40
class InvocationError (AttributeError ):
@@ -45,15 +51,15 @@ class AnswerError(AttributeError):
45
51
46
52
47
53
class Invocation (object ):
48
- def __init__ (self , mock , method_name ) :
54
+ def __init__ (self , mock : Mock , method_name : str ) -> None :
49
55
self .mock = mock
50
56
self .method_name = method_name
51
57
self .strict = mock .strict
52
58
53
- self .params : Tuple [Any , ...] = ()
54
- self .named_params : Dict [str , Any ] = {}
59
+ self .params : tuple [Any , ...] = ()
60
+ self .named_params : dict [str , Any ] = {}
55
61
56
- def _remember_params (self , params , named_params ) :
62
+ def _remember_params (self , params : tuple , named_params : dict ) -> None :
57
63
self .params = params
58
64
self .named_params = named_params
59
65
@@ -68,27 +74,29 @@ def __repr__(self):
68
74
69
75
70
76
class RealInvocation (Invocation , ABC ):
71
- def __init__ (self , mock , method_name ) :
77
+ def __init__ (self , mock : Mock , method_name : str ) -> None :
72
78
super (RealInvocation , self ).__init__ (mock , method_name )
73
79
self .verified = False
74
80
self .verified_inorder = False
75
81
76
82
77
83
class RememberedInvocation (RealInvocation ):
78
- def ensure_mocked_object_has_method (self , method_name ) :
84
+ def ensure_mocked_object_has_method (self , method_name : str ) -> None :
79
85
if not self .mock .has_method (method_name ):
80
86
raise InvocationError (
81
87
"You tried to call a method '%s' the object (%s) doesn't "
82
88
"have." % (method_name , self .mock .mocked_obj ))
83
89
84
- def ensure_signature_matches (self , method_name , args , kwargs ):
90
+ def ensure_signature_matches (
91
+ self , method_name : str , args : tuple , kwargs : dict
92
+ ) -> None :
85
93
sig = self .mock .get_signature (method_name )
86
94
if not sig :
87
95
return
88
96
89
97
signature .match_signature (sig , args , kwargs )
90
98
91
- def __call__ (self , * params , ** named_params ) :
99
+ def __call__ (self , * params : Any , ** named_params : Any ) -> Any | None :
92
100
if self .mock .eat_self (self .method_name ):
93
101
params_without_first_arg = params [1 :]
94
102
else :
@@ -141,7 +149,7 @@ class RememberedProxyInvocation(RealInvocation):
141
149
142
150
Calls method on original object and returns it's return value.
143
151
"""
144
- def __call__ (self , * params , ** named_params ) :
152
+ def __call__ (self , * params : Any , ** named_params : Any ) -> Any :
145
153
self ._remember_params (params , named_params )
146
154
self .mock .remember (self )
147
155
obj = self .mock .spec
@@ -174,7 +182,7 @@ def compare(p1, p2):
174
182
return False
175
183
return True
176
184
177
- def capture_arguments (self , invocation ) :
185
+ def capture_arguments (self , invocation : RealInvocation ) -> None :
178
186
"""Capture arguments of `invocation` into "capturing" matchers of self.
179
187
180
188
This is used in conjunction with "capturing" matchers like
@@ -204,7 +212,7 @@ def capture_arguments(self, invocation):
204
212
p1 .capture_value (p2 )
205
213
206
214
207
- def _remember_params (self , params , named_params ) :
215
+ def _remember_params (self , params : tuple , named_params : dict ) -> None :
208
216
if (
209
217
contains_strict (params , Ellipsis )
210
218
and (params [- 1 ] is not Ellipsis or named_params )
@@ -231,7 +239,7 @@ def wrap(p):
231
239
# Note: matches(a, b) does not imply matches(b, a) because
232
240
# the left side might contain wildcards (like Ellipsis) or matchers.
233
241
# In its current form the right side is a concrete call signature.
234
- def matches (self , invocation ) : # noqa: C901 (too complex)
242
+ def matches (self , invocation : Invocation ) -> bool : # noqa: C901, E501 (too complex)
235
243
if self .method_name != invocation .method_name :
236
244
return False
237
245
@@ -294,11 +302,16 @@ class VerifiableInvocation(MatchingInvocation):
294
302
call. But the `__call__` is essentially virtual and can contain
295
303
placeholders and matchers.
296
304
"""
297
- def __init__ (self , mock , method_name , verification ):
305
+ def __init__ (
306
+ self ,
307
+ mock : Mock ,
308
+ method_name : str ,
309
+ verification : verificationModule .VerificationMode
310
+ ) -> None :
298
311
super (VerifiableInvocation , self ).__init__ (mock , method_name )
299
312
self .verification = verification
300
313
301
- def __call__ (self , * params , ** named_params ) :
314
+ def __call__ (self , * params : Any , ** named_params : Any ) -> None :
302
315
self ._remember_params (params , named_params )
303
316
matched_invocations = []
304
317
for invocation in self .mock .invocations :
@@ -321,7 +334,9 @@ def __call__(self, *params, **named_params):
321
334
stub .allow_zero_invocations = True
322
335
323
336
324
- def verification_has_lower_bound_of_zero (verification ):
337
+ def verification_has_lower_bound_of_zero (
338
+ verification : verificationModule .VerificationMode | None
339
+ ) -> bool :
325
340
if (
326
341
isinstance (verification , verificationModule .Times )
327
342
and verification .wanted_count == 0
@@ -372,7 +387,13 @@ class StubbedInvocation(MatchingInvocation):
372
387
there is no "new" keyword in Python.)
373
388
374
389
"""
375
- def __init__ (self , mock , method_name , verification = None , strict = None ):
390
+ def __init__ (
391
+ self ,
392
+ mock : Mock ,
393
+ method_name : str ,
394
+ verification : verificationModule .VerificationMode | None = None ,
395
+ strict : bool | None = None
396
+ ) -> None :
376
397
super (StubbedInvocation , self ).__init__ (mock , method_name )
377
398
378
399
#: Holds the verification set up via `expect`.
@@ -391,26 +412,25 @@ def __init__(self, mock, method_name, verification=None, strict=None):
391
412
392
413
#: Set if `verifyStubbedInvocationsAreUsed` should pass, regardless
393
414
#: of any factual invocation. E.g. set by `expect(..., times=0)`
394
- if verification_has_lower_bound_of_zero (verification ):
395
- self .allow_zero_invocations = True
396
- else :
397
- self .allow_zero_invocations = False
398
-
415
+ self .allow_zero_invocations : bool = \
416
+ verification_has_lower_bound_of_zero (verification )
399
417
400
- def ensure_mocked_object_has_method (self , method_name ) :
418
+ def ensure_mocked_object_has_method (self , method_name : str ) -> None :
401
419
if not self .mock .has_method (method_name ):
402
420
raise InvocationError (
403
421
"You tried to stub a method '%s' the object (%s) doesn't "
404
422
"have." % (method_name , self .mock .mocked_obj ))
405
423
406
- def ensure_signature_matches (self , method_name , args , kwargs ):
424
+ def ensure_signature_matches (
425
+ self , method_name : str , args : tuple , kwargs : dict
426
+ ) -> None :
407
427
sig = self .mock .get_signature (method_name )
408
428
if not sig :
409
429
return
410
430
411
431
signature .match_signature_allowing_placeholders (sig , args , kwargs )
412
432
413
- def __call__ (self , * params , ** named_params ) :
433
+ def __call__ (self , * params : Any , ** named_params : Any ) -> AnswerSelector :
414
434
if self .strict :
415
435
self .ensure_mocked_object_has_method (self .method_name )
416
436
self .ensure_signature_matches (
@@ -421,13 +441,13 @@ def __call__(self, *params, **named_params):
421
441
self .mock .finish_stubbing (self )
422
442
return AnswerSelector (self )
423
443
424
- def forget_self (self ):
444
+ def forget_self (self ) -> None :
425
445
self .mock .forget_stubbed_invocation (self )
426
446
427
- def add_answer (self , answer ) :
447
+ def add_answer (self , answer : Callable ) -> None :
428
448
self .answers .add (answer )
429
449
430
- def answer_first (self , * args , ** kwargs ) :
450
+ def answer_first (self , * args : Any , ** kwargs : Any ) -> Any :
431
451
self .used += 1
432
452
return self .answers .answer (* args , ** kwargs )
433
453
@@ -466,54 +486,52 @@ def should_answer(self, invocation: RememberedInvocation) -> None:
466
486
# to get verified 'implicitly', on-the-go, so we set this flag here.
467
487
invocation .verified = True
468
488
469
- def verify (self ):
489
+ def verify (self ) -> None :
470
490
if self .verification :
471
491
self .verification .verify (self , self .used )
472
492
473
- def check_used (self ):
493
+ def check_used (self ) -> None :
474
494
if not self .allow_zero_invocations and self .used < len (self .answers ):
475
495
raise verificationModule .VerificationError (
476
496
"\n Unused stub: %s" % self )
477
497
478
498
479
- def return_ (value ) :
480
- def answer (* args , ** kwargs ):
499
+ def return_ (value : T ) -> Callable [..., T ] :
500
+ def answer (* args , ** kwargs ) -> T :
481
501
return value
482
502
return answer
483
503
484
- def raise_ (exception ) :
485
- def answer (* args , ** kwargs ):
504
+ def raise_ (exception : Exception | type [ Exception ]) -> Callable [..., NoReturn ] :
505
+ def answer (* args , ** kwargs ) -> NoReturn :
486
506
raise exception
487
507
return answer
488
508
489
-
490
- def discard_self (function ):
491
- def function_without_self (* args , ** kwargs ):
509
+ def discard_self (function : Callable [..., T ]) -> Callable [..., T ]:
510
+ def function_without_self (* args , ** kwargs ) -> T :
492
511
args = args [1 :]
493
512
return function (* args , ** kwargs )
494
-
495
513
return function_without_self
496
514
497
515
498
516
class AnswerSelector (object ):
499
- def __init__ (self , invocation ) :
517
+ def __init__ (self , invocation : StubbedInvocation ) -> None :
500
518
self .invocation = invocation
501
519
self .discard_first_arg = \
502
520
invocation .mock .eat_self (invocation .method_name )
503
521
504
- def thenReturn (self , * return_values ) :
522
+ def thenReturn (self , * return_values : Any ) -> Self :
505
523
for return_value in return_values or (None ,):
506
524
answer = return_ (return_value )
507
525
self .__then (answer )
508
526
return self
509
527
510
- def thenRaise (self , * exceptions ) :
528
+ def thenRaise (self , * exceptions : Exception | type [ Exception ]) -> Self :
511
529
for exception in exceptions or (Exception ,):
512
530
answer = raise_ (exception )
513
531
self .__then (answer )
514
532
return self
515
533
516
- def thenAnswer (self , * callables ) :
534
+ def thenAnswer (self , * callables : Callable ) -> Self :
517
535
if not callables :
518
536
raise TypeError ("No answer function provided" )
519
537
for callable in callables :
@@ -523,7 +541,7 @@ def thenAnswer(self, *callables):
523
541
self .__then (answer )
524
542
return self
525
543
526
- def thenCallOriginalImplementation (self ):
544
+ def thenCallOriginalImplementation (self ) -> Self :
527
545
answer = self .invocation .mock .get_original_method (
528
546
self .invocation .method_name
529
547
)
@@ -545,36 +563,36 @@ def thenCallOriginalImplementation(self):
545
563
self .__then (answer )
546
564
return self
547
565
548
- def __then (self , answer ) :
566
+ def __then (self , answer : Callable ) -> None :
549
567
self .invocation .add_answer (answer )
550
568
551
- def __enter__ (self ):
569
+ def __enter__ (self ) -> None :
552
570
pass
553
571
554
- def __exit__ (self , * exc_info ):
572
+ def __exit__ (self , * exc_info ) -> None :
555
573
self .invocation .verify ()
556
574
if os .environ .get ("MOCKITO_CONTEXT_MANAGERS_CHECK_USAGE" , "1" ) == "1" :
557
575
self .invocation .check_used ()
558
576
self .invocation .forget_self ()
559
577
560
578
561
579
class CompositeAnswer (object ):
562
- def __init__ (self ):
580
+ def __init__ (self ) -> None :
563
581
#: Container for answers, which are just ordinary callables
564
- self .answers : Deque [Callable ] = deque ()
582
+ self .answers : deque [Callable ] = deque ()
565
583
566
584
#: Counter for the maximum answers we ever had
567
585
self .answer_count = 0
568
586
569
- def __len__ (self ):
587
+ def __len__ (self ) -> int :
570
588
# The minimum is '1' bc we always have a default answer of 'None'
571
589
return max (1 , self .answer_count )
572
590
573
- def add (self , answer ) :
591
+ def add (self , answer : Callable ) -> None :
574
592
self .answer_count += 1
575
593
self .answers .append (answer )
576
594
577
- def answer (self , * args , ** kwargs ) :
595
+ def answer (self , * args : Any , ** kwargs : Any ) -> Any :
578
596
if len (self .answers ) == 0 :
579
597
return None
580
598
0 commit comments