Skip to content

Commit db45700

Browse files
committed
✨ Refactor prefetch decorators to handle related values prefetch
1 parent 7595305 commit db45700

File tree

1 file changed

+52
-18
lines changed

1 file changed

+52
-18
lines changed

src/database/utils/preload.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# Copyright (c) NiceBots
22
# SPDX-License-Identifier: MIT
3+
from collections.abc import Sequence, Callable
4+
from typing import Literal, overload
5+
from functools import partial
36

47
from discord.ext import commands
58

69
from src import custom
710
from src.database.models import Guild, User
811

9-
10-
async def _preload_user(ctx: custom.Context) -> bool:
12+
async def _preload_user(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
1113
"""Preload the user object into the context object.
1214
1315
Args:
@@ -20,16 +22,13 @@ async def _preload_user(ctx: custom.Context) -> bool:
2022
2123
"""
2224
if isinstance(ctx, custom.ExtContext):
23-
ctx.user_obj = await User.get_or_none(id=ctx.author.id) if ctx.author else None
25+
ctx.user_obj = await User.get_or_none(id=ctx.author.id).prefetch_related(*prefetch_related) if ctx.author else None
2426
else:
25-
ctx.user_obj = await User.get_or_none(id=ctx.user.id) if ctx.user else None
27+
ctx.user_obj = await User.get_or_none(id=ctx.user.id).prefetch_related(*prefetch_related) if ctx.user else None
2628
return True
2729

2830

29-
preload_user = commands.check(_preload_user) # pyright: ignore [reportArgumentType]
30-
31-
32-
async def _preload_guild(ctx: custom.Context) -> bool:
31+
async def _preload_guild(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
3332
"""Preload the guild object into the context object.
3433
3534
Args:
@@ -41,14 +40,12 @@ async def _preload_guild(ctx: custom.Context) -> bool:
4140
bool: (True) always.
4241
4342
"""
44-
ctx.guild_obj = await Guild.get_or_none(id=ctx.guild.id) if ctx.guild else None
43+
ctx.guild_obj = await Guild.get_or_none(id=ctx.guild.id).prefetch_related(*prefetch_related) if ctx.guild else None
4544
return True
4645

4746

48-
preload_guild = commands.check(_preload_guild) # pyright: ignore [reportArgumentType]
4947

50-
51-
async def _preload_or_create_user(ctx: custom.Context) -> bool:
48+
async def _preload_or_create_user(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
5249
"""Preload or create the user object into the context object. If the user object does not exist, create it.
5350
5451
Args:
@@ -60,14 +57,16 @@ async def _preload_or_create_user(ctx: custom.Context) -> bool:
6057
bool: (True) always.
6158
6259
"""
63-
ctx.user_obj, _ = await User.get_or_create(id=ctx.author.id) if ctx.author else (None, None)
60+
user: User | None
61+
user, _ = await User.get_or_create(id=ctx.author.id) if ctx.author else (None, None)
62+
if user is not None:
63+
await user.fetch_related(*prefetch_related)
64+
ctx.user_obj = user
6465
return True
6566

6667

67-
preload_or_create_user = commands.check(_preload_or_create_user)
68-
6968

70-
async def _preload_or_create_guild(ctx: custom.Context) -> bool:
69+
async def _preload_or_create_guild(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
7170
"""Preload or create the guild object into the context object. If the guild object does not exist, create it.
7271
7372
Args:
@@ -79,8 +78,43 @@ async def _preload_or_create_guild(ctx: custom.Context) -> bool:
7978
bool: (True) always.
8079
8180
"""
82-
ctx.guild_obj, _ = await Guild.get_or_create(id=ctx.guild.id) if ctx.guild else (None, None)
81+
guild: Guild | None
82+
guild, _ = await Guild.get_or_create(id=ctx.guild.id) if ctx.guild else (None, None)
83+
if guild is not None:
84+
await guild.fetch_related(*prefetch_related)
85+
ctx.guild_obj = guild
8386
return True
8487

8588

86-
preload_or_create_guild = commands.check(_preload_or_create_guild) # pyright: ignore [reportArgumentType]
89+
type PreloadFunction = Callable[[custom.Context, Sequence[str]], Literal[True]]
90+
91+
@overload
92+
def preload_x[T](f: T, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None) -> T:
93+
"""When used as a direct decorator: @preload_or_create_user"""
94+
...
95+
96+
@overload
97+
def preload_x(f: None = None, *, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None) -> Callable[[T], T]:
98+
"""When used with arguments: @preload_or_create_user(prefetch_related=[...])"""
99+
...
100+
101+
def preload_x[T](
102+
f: Callable[[T], T] | None = None,
103+
*,
104+
preloader: PreloadFunction,
105+
prefetch_related: Sequence[str] | None = None
106+
):
107+
"""Generic preloader that can be specialized for user or guild."""
108+
if prefetch_related is None:
109+
prefetch_related = []
110+
111+
func = partial(preloader, prefetch_related=prefetch_related)
112+
113+
check_decorator = commands.check(func)
114+
115+
return check_decorator(f) if f is not None else check_decorator
116+
117+
preload_guild = partial(preload_x, preloader=_preload_guild)
118+
preload_user = partial(preload_x, preloader=_preload_user)
119+
preload_or_create_guild = partial(preload_x, preloader=_preload_or_create_guild)
120+
preload_or_create_user = partial(preload_x, preloader=_preload_or_create_user)

0 commit comments

Comments
 (0)