11# Copyright (c) NiceBots
22# SPDX-License-Identifier: MIT
3+ from collections .abc import Sequence , Callable
4+ from typing import Literal , overload
5+ from functools import partial
36
47from discord .ext import commands
58
69from src import custom
710from src .database .models import Guild , User
811
9-
10- async def _preload_user (ctx : custom .Context ) -> bool :
12+ async def _preload_user (ctx : custom .Context , prefetch_related : Sequence [str ]) -> Literal [True ]:
1113 """Preload the user object into the context object.
1214
1315 Args:
@@ -20,16 +22,13 @@ async def _preload_user(ctx: custom.Context) -> bool:
2022
2123 """
2224 if isinstance (ctx , custom .ExtContext ):
23- ctx .user_obj = await User .get_or_none (id = ctx .author .id ) if ctx .author else None
25+ ctx .user_obj = await User .get_or_none (id = ctx .author .id ). prefetch_related ( * prefetch_related ) if ctx .author else None
2426 else :
25- ctx .user_obj = await User .get_or_none (id = ctx .user .id ) if ctx .user else None
27+ ctx .user_obj = await User .get_or_none (id = ctx .user .id ). prefetch_related ( * prefetch_related ) if ctx .user else None
2628 return True
2729
2830
29- preload_user = commands .check (_preload_user ) # pyright: ignore [reportArgumentType]
30-
31-
32- async def _preload_guild (ctx : custom .Context ) -> bool :
31+ async def _preload_guild (ctx : custom .Context , prefetch_related : Sequence [str ]) -> Literal [True ]:
3332 """Preload the guild object into the context object.
3433
3534 Args:
@@ -41,14 +40,12 @@ async def _preload_guild(ctx: custom.Context) -> bool:
4140 bool: (True) always.
4241
4342 """
44- ctx .guild_obj = await Guild .get_or_none (id = ctx .guild .id ) if ctx .guild else None
43+ ctx .guild_obj = await Guild .get_or_none (id = ctx .guild .id ). prefetch_related ( * prefetch_related ) if ctx .guild else None
4544 return True
4645
4746
48- preload_guild = commands .check (_preload_guild ) # pyright: ignore [reportArgumentType]
4947
50-
51- async def _preload_or_create_user (ctx : custom .Context ) -> bool :
48+ async def _preload_or_create_user (ctx : custom .Context , prefetch_related : Sequence [str ]) -> Literal [True ]:
5249 """Preload or create the user object into the context object. If the user object does not exist, create it.
5350
5451 Args:
@@ -60,14 +57,16 @@ async def _preload_or_create_user(ctx: custom.Context) -> bool:
6057 bool: (True) always.
6158
6259 """
63- ctx .user_obj , _ = await User .get_or_create (id = ctx .author .id ) if ctx .author else (None , None )
60+ user : User | None
61+ user , _ = await User .get_or_create (id = ctx .author .id ) if ctx .author else (None , None )
62+ if user is not None :
63+ await user .fetch_related (* prefetch_related )
64+ ctx .user_obj = user
6465 return True
6566
6667
67- preload_or_create_user = commands .check (_preload_or_create_user )
68-
6968
70- async def _preload_or_create_guild (ctx : custom .Context ) -> bool :
69+ async def _preload_or_create_guild (ctx : custom .Context , prefetch_related : Sequence [ str ] ) -> Literal [ True ] :
7170 """Preload or create the guild object into the context object. If the guild object does not exist, create it.
7271
7372 Args:
@@ -79,8 +78,43 @@ async def _preload_or_create_guild(ctx: custom.Context) -> bool:
7978 bool: (True) always.
8079
8180 """
82- ctx .guild_obj , _ = await Guild .get_or_create (id = ctx .guild .id ) if ctx .guild else (None , None )
81+ guild : Guild | None
82+ guild , _ = await Guild .get_or_create (id = ctx .guild .id ) if ctx .guild else (None , None )
83+ if guild is not None :
84+ await guild .fetch_related (* prefetch_related )
85+ ctx .guild_obj = guild
8386 return True
8487
8588
86- preload_or_create_guild = commands .check (_preload_or_create_guild ) # pyright: ignore [reportArgumentType]
89+ type PreloadFunction = Callable [[custom .Context , Sequence [str ]], Literal [True ]]
90+
91+ @overload
92+ def preload_x [T ](f : T , preloader : PreloadFunction , prefetch_related : Sequence [str ] | None = None ) -> T :
93+ """When used as a direct decorator: @preload_or_create_user"""
94+ ...
95+
96+ @overload
97+ def preload_x (f : None = None , * , preloader : PreloadFunction , prefetch_related : Sequence [str ] | None = None ) -> Callable [[T ], T ]:
98+ """When used with arguments: @preload_or_create_user(prefetch_related=[...])"""
99+ ...
100+
101+ def preload_x [T ](
102+ f : Callable [[T ], T ] | None = None ,
103+ * ,
104+ preloader : PreloadFunction ,
105+ prefetch_related : Sequence [str ] | None = None
106+ ):
107+ """Generic preloader that can be specialized for user or guild."""
108+ if prefetch_related is None :
109+ prefetch_related = []
110+
111+ func = partial (preloader , prefetch_related = prefetch_related )
112+
113+ check_decorator = commands .check (func )
114+
115+ return check_decorator (f ) if f is not None else check_decorator
116+
117+ preload_guild = partial (preload_x , preloader = _preload_guild )
118+ preload_user = partial (preload_x , preloader = _preload_user )
119+ preload_or_create_guild = partial (preload_x , preloader = _preload_or_create_guild )
120+ preload_or_create_user = partial (preload_x , preloader = _preload_or_create_user )
0 commit comments