Skip to content

Commit 47350d2

Browse files
authored
Merge pull request #3488 from nachomaiz/sprite-typing-enhancements
Typing enhancements for the `pygame.sprite` module
2 parents 2e636c1 + ac03a98 commit 47350d2

File tree

1 file changed

+61
-96
lines changed

1 file changed

+61
-96
lines changed

buildconfig/stubs/pygame/sprite.pyi

Lines changed: 61 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import types
23
from collections.abc import Callable, Iterable, Iterator
34
from typing import (
@@ -10,17 +11,48 @@ from typing import (
1011
Union,
1112
)
1213

14+
# use typing_extensions for compatibility with older Python versions
15+
if sys.version_info >= (3, 13):
16+
from warnings import deprecated
17+
else:
18+
from typing_extensions import deprecated
19+
20+
if sys.version_info >= (3, 11):
21+
from typing import Self
22+
else:
23+
from typing_extensions import Self
24+
1325
from pygame.mask import Mask
1426
from pygame.rect import FRect, Rect
1527
from pygame.surface import Surface
1628
from pygame.typing import Point, RectLike
17-
from typing_extensions import deprecated # added in 3.13
29+
30+
# define some useful protocols first, which sprite functions accept
31+
# sprite functions don't need all sprite attributes to be present in the
32+
# arguments passed, they only use a few which are marked in the below protocols
33+
class _HasRect(Protocol):
34+
@property
35+
def rect(self) -> Optional[Union[FRect, Rect]]: ...
36+
37+
# image in addition to rect
38+
class _HasImageAndRect(_HasRect, Protocol):
39+
@property
40+
def image(self) -> Optional[Surface]: ...
41+
42+
# mask in addition to rect
43+
class _HasMaskAndRect(_HasRect, Protocol):
44+
mask: Mask
45+
46+
# radius in addition to rect
47+
class _HasRadiusAndRect(_HasRect, Protocol):
48+
radius: float
1849

1950
# non-generic Group, used in Sprite
20-
_Group = AbstractGroup[_SpriteSupportsGroup]
51+
_Group = AbstractGroup[Any]
2152

2253
# protocol helps with structural subtyping for typevars in sprite group generics
23-
class _SupportsSprite(Protocol):
54+
# and allows the use of any class with the required attributes and methods
55+
class _SupportsSprite(_HasImageAndRect, Protocol):
2456
@property
2557
def image(self) -> Optional[Surface]: ...
2658
@image.setter
@@ -33,7 +65,6 @@ class _SupportsSprite(Protocol):
3365
def layer(self) -> int: ...
3466
@layer.setter
3567
def layer(self, value: int) -> None: ...
36-
def __init__(self, *groups: _Group) -> None: ...
3768
def add_internal(self, group: _Group) -> None: ...
3869
def remove_internal(self, group: _Group) -> None: ...
3970
def update(self, *args: Any, **kwargs: Any) -> None: ...
@@ -75,10 +106,10 @@ class Sprite(_SupportsSprite):
75106
def remove(self, *groups: _Group) -> None: ...
76107
def kill(self) -> None: ...
77108
def alive(self) -> bool: ...
78-
def groups(self) -> list[_Group]: ...
109+
def groups(self) -> list[AbstractGroup[_SupportsSprite]]: ...
79110

80111
# concrete dirty sprite implementation class
81-
class DirtySprite(_SupportsDirtySprite):
112+
class DirtySprite(Sprite, _SupportsDirtySprite):
82113
dirty: int
83114
blendmode: int
84115
source_rect: Union[FRect, Rect]
@@ -87,59 +118,14 @@ class DirtySprite(_SupportsDirtySprite):
87118
def _set_visible(self, val: int) -> None: ...
88119
def _get_visible(self) -> int: ...
89120

90-
# used as a workaround for typing.Self because it is added in python 3.11
91-
_TGroup = TypeVar("_TGroup", bound=AbstractGroup)
92-
93-
# define some useful protocols first, which sprite functions accept
94-
# sprite functions don't need all sprite attributes to be present in the
95-
# arguments passed, they only use a few which are marked in the below protocols
96-
class _HasRect(Protocol):
97-
@property
98-
def rect(self) -> Optional[Union[FRect, Rect]]: ...
99-
100-
# image in addition to rect
101-
class _HasImageAndRect(_HasRect, Protocol):
102-
@property
103-
def image(self) -> Optional[Surface]: ...
104-
105-
# mask in addition to rect
106-
class _HasMaskAndRect(_HasRect, Protocol):
107-
mask: Mask
108-
109-
# radius in addition to rect
110-
class _HasRadiusAndRect(_HasRect, Protocol):
111-
radius: float
112-
113-
class _SpriteSupportsGroup(_SupportsSprite, _HasImageAndRect, Protocol): ...
114-
class _DirtySpriteSupportsGroup(_SupportsDirtySprite, _HasImageAndRect, Protocol): ...
115-
116-
# typevar bound to Sprite, _SpriteSupportsGroup Protocol ensures sprite
121+
# typevar bound to Sprite, _SupportsSprite Protocol ensures sprite
117122
# subclass passed to group has image and rect attributes
118-
_TSprite = TypeVar("_TSprite", bound=_SpriteSupportsGroup)
119-
_TSprite2 = TypeVar("_TSprite2", bound=_SpriteSupportsGroup)
120-
121-
# almost the same as _TSprite but bound to DirtySprite
122-
_TDirtySprite = TypeVar("_TDirtySprite", bound=_DirtySpriteSupportsGroup)
123+
_TSprite = TypeVar("_TSprite", bound=_SupportsSprite)
124+
_TSprite2 = TypeVar("_TSprite2", bound=_SupportsSprite)
125+
_TDirtySprite = TypeVar("_TDirtySprite", bound=_SupportsDirtySprite)
123126

124-
# Below code demonstrates the advantages of the _SpriteSupportsGroup protocol
125-
126-
# typechecker should error, regular Sprite does not support Group.draw due to
127-
# missing image and rect attributes
128-
# a = Group(Sprite())
129-
130-
# typechecker should error, other Sprite attributes are also needed for Group
131-
# class MySprite:
132-
# image: Surface
133-
# rect: Rect
134-
#
135-
# b = Group(MySprite())
136-
137-
# typechecker should pass
138-
# class MySprite(Sprite):
139-
# image: Surface
140-
# rect: Rect
141-
#
142-
# b = Group(MySprite())
127+
# typevar for sprite or iterable of sprites, used in Group init, add and remove
128+
_SpriteOrIterable = Union[_TSprite, Iterable[_SpriteOrIterable[_TSprite]]]
143129

144130
class AbstractGroup(Generic[_TSprite]):
145131
spritedict: dict[_TSprite, Optional[Union[FRect, Rect]]]
@@ -153,17 +139,11 @@ class AbstractGroup(Generic[_TSprite]):
153139
def add_internal(self, sprite: _TSprite, layer: None = None) -> None: ...
154140
def remove_internal(self, sprite: _TSprite) -> None: ...
155141
def has_internal(self, sprite: _TSprite) -> bool: ...
156-
def copy(self: _TGroup) -> _TGroup: ... # typing.Self is py3.11+
142+
def copy(self) -> Self: ...
157143
def sprites(self) -> list[_TSprite]: ...
158-
def add(
159-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
160-
) -> None: ...
161-
def remove(
162-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
163-
) -> None: ...
164-
def has(
165-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
166-
) -> bool: ...
144+
def add(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
145+
def remove(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
146+
def has(self, *sprites: _SpriteOrIterable[_TSprite]) -> bool: ...
167147
def update(self, *args: Any, **kwargs: Any) -> None: ...
168148
def draw(
169149
self, surface: Surface, bgd: Optional[Surface] = None, special_flags: int = 0
@@ -176,16 +156,14 @@ class AbstractGroup(Generic[_TSprite]):
176156
def empty(self) -> None: ...
177157

178158
class Group(AbstractGroup[_TSprite]):
179-
def __init__(
180-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
181-
) -> None: ...
159+
def __init__(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
182160

183161
# these are aliased in the code too
184162
@deprecated("Use `pygame.sprite.Group` instead")
185-
class RenderPlain(Group): ...
163+
class RenderPlain(Group[_TSprite]): ...
186164

187165
@deprecated("Use `pygame.sprite.Group` instead")
188-
class RenderClear(Group): ...
166+
class RenderClear(Group[_TSprite]): ...
189167

190168
class RenderUpdates(Group[_TSprite]): ...
191169

@@ -194,23 +172,9 @@ class OrderedUpdates(RenderUpdates[_TSprite]): ...
194172

195173
class LayeredUpdates(AbstractGroup[_TSprite]):
196174
def __init__(
197-
self,
198-
*sprites: Union[
199-
_TSprite,
200-
AbstractGroup[_TSprite],
201-
Iterable[Union[_TSprite, AbstractGroup[_TSprite]]],
202-
],
203-
**kwargs: Any,
204-
) -> None: ...
205-
def add(
206-
self,
207-
*sprites: Union[
208-
_TSprite,
209-
AbstractGroup[_TSprite],
210-
Iterable[Union[_TSprite, AbstractGroup[_TSprite]]],
211-
],
212-
**kwargs: Any,
175+
self, *sprites: _SpriteOrIterable[_TSprite], **kwargs: Any
213176
) -> None: ...
177+
def add(self, *sprites: _SpriteOrIterable[_TSprite], **kwargs: Any) -> None: ...
214178
def get_sprites_at(self, pos: Point) -> list[_TSprite]: ...
215179
def get_sprite(self, idx: int) -> _TSprite: ...
216180
def remove_sprites_of_layer(self, layer_nr: int) -> list[_TSprite]: ...
@@ -226,7 +190,6 @@ class LayeredUpdates(AbstractGroup[_TSprite]):
226190
def switch_layer(self, layer1_nr: int, layer2_nr: int) -> None: ...
227191

228192
class LayeredDirty(LayeredUpdates[_TDirtySprite]):
229-
def __init__(self, *sprites: _TDirtySprite, **kwargs: Any) -> None: ...
230193
def draw(
231194
self,
232195
surface: Surface,
@@ -238,9 +201,7 @@ class LayeredDirty(LayeredUpdates[_TDirtySprite]):
238201
def repaint_rect(self, screen_rect: RectLike) -> None: ...
239202
def set_clip(self, screen_rect: Optional[RectLike] = None) -> None: ...
240203
def get_clip(self) -> Union[FRect, Rect]: ...
241-
def set_timing_threshold(
242-
self, time_ms: SupportsFloat
243-
) -> None: ... # This actually accept any value
204+
def set_timing_threshold(self, time_ms: SupportsFloat) -> None: ...
244205
@deprecated(
245206
"since 2.1.1. Use `pygame.sprite.LayeredDirty.set_timing_threshold` instead"
246207
)
@@ -279,11 +240,15 @@ _SupportsCollideMask = Union[_HasImageAndRect, _HasMaskAndRect]
279240
def collide_mask(
280241
left: _SupportsCollideMask, right: _SupportsCollideMask
281242
) -> Optional[tuple[int, int]]: ...
243+
244+
# _HasRect typevar for sprite collide functions
245+
_THasRect = TypeVar("_THasRect", bound=_HasRect)
246+
282247
def spritecollide(
283-
sprite: _HasRect,
248+
sprite: _THasRect,
284249
group: AbstractGroup[_TSprite],
285250
dokill: bool,
286-
collided: Optional[Callable[[_TSprite, _TSprite2], Any]] = None,
251+
collided: Optional[Callable[[_THasRect, _TSprite], Any]] = None,
287252
) -> list[_TSprite]: ...
288253
def groupcollide(
289254
groupa: AbstractGroup[_TSprite],
@@ -293,7 +258,7 @@ def groupcollide(
293258
collided: Optional[Callable[[_TSprite, _TSprite2], Any]] = None,
294259
) -> dict[_TSprite, list[_TSprite2]]: ...
295260
def spritecollideany(
296-
sprite: _HasRect,
261+
sprite: _THasRect,
297262
group: AbstractGroup[_TSprite],
298-
collided: Optional[Callable[[_TSprite, _TSprite2], Any]] = None,
263+
collided: Optional[Callable[[_THasRect, _TSprite], Any]] = None,
299264
) -> Optional[_TSprite]: ...

0 commit comments

Comments
 (0)