Skip to content

Commit 42cdf62

Browse files
Add app commands declared client callback
1 parent 02f23bf commit 42cdf62

File tree

2 files changed

+67
-7
lines changed

2 files changed

+67
-7
lines changed

tanjun/abc.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3631,12 +3631,40 @@ async def open(self) -> None:
36313631
"""
36323632

36333633

3634+
class DeclaredCommands(abc.ABC):
3635+
__slots__ = ()
3636+
3637+
@property
3638+
@abc.abstractmethod
3639+
def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]:
3640+
"""The declared command builders."""
3641+
3642+
@property
3643+
@abc.abstractmethod
3644+
def commands(self) -> collections.Sequence[hikari.PartialCommand]:
3645+
"""The declared command objects."""
3646+
3647+
@property
3648+
@abc.abstractmethod
3649+
def guild_id(self) -> typing.Optional[hikari.Snowflake]:
3650+
"""Id of the guild these commands were declared for.
3651+
3652+
This will be [None][] if they were declared globally.
3653+
"""
3654+
3655+
36343656
class ClientCallbackNames(str, enum.Enum):
36353657
"""Enum of the standard client callback names.
36363658
36373659
These should be dispatched by all [tanjun.abc.Client][] implementations.
36383660
"""
36393661

3662+
APP_COMMANDS_DECLARED = "app_commands_delcared"
3663+
"""Called when the application commands are declared through the client.
3664+
3665+
One positional argument of type [DeclaredCommands][].
3666+
"""
3667+
36403668
CLOSED = "closed"
36413669
"""Called when the client has finished closing.
36423670

tanjun/clients.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,33 @@ async def __call__(self) -> None:
491491
self.client.remove_client_callback(ClientCallbackNames.STARTING, self)
492492

493493

494+
class _DeclaredCommands(tanjun.DeclaredCommands):
495+
__slots__ = ("_builders", "_commands", "_guild_id")
496+
497+
def __init__(
498+
self,
499+
builders: collections.Sequence[hikari.api.CommandBuilder],
500+
commands: collections.Sequence[hikari.PartialCommand],
501+
guild_id: typing.Optional[hikari.Snowflake],
502+
/,
503+
) -> None:
504+
self._builders = builders
505+
self._commands = commands
506+
self._guild_id = guild_id
507+
508+
@property
509+
def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]:
510+
return self._builders
511+
512+
@property
513+
def commands(self) -> collections.Sequence[hikari.PartialCommand]:
514+
return self._commands
515+
516+
@property
517+
def guild_id(self) -> typing.Optional[hikari.Snowflake]:
518+
return self._guild_id
519+
520+
494521
class Client(tanjun.Client):
495522
"""Tanjun's standard [tanjun.abc.Client][] implementation.
496523
@@ -1312,15 +1339,15 @@ async def declare_application_commands(
13121339
user_ids = user_ids or {}
13131340
names_to_commands: dict[tuple[hikari.CommandType, str], tanjun.AppCommand[typing.Any]] = {}
13141341
conflicts: set[tuple[hikari.CommandType, str]] = set()
1315-
builders: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {}
1342+
builders_dict: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {}
13161343
message_count = 0
13171344
slash_count = 0
13181345
user_count = 0
13191346

13201347
for command in commands:
13211348
key = (command.type, command.name)
13221349
names_to_commands[key] = command
1323-
if key in builders:
1350+
if key in builders_dict:
13241351
conflicts.add(key)
13251352

13261353
builder = command.build()
@@ -1345,7 +1372,7 @@ async def declare_application_commands(
13451372
if builder.is_dm_enabled is hikari.UNDEFINED:
13461373
builder.set_is_dm_enabled(self.dms_enabled_for_app_cmds)
13471374

1348-
builders[key] = builder
1375+
builders_dict[key] = builder
13491376

13501377
if conflicts:
13511378
raise ValueError(
@@ -1367,16 +1394,17 @@ async def declare_application_commands(
13671394

13681395
if not force:
13691396
registered_commands = await self._rest.fetch_application_commands(application, guild=guild)
1370-
if len(registered_commands) == len(builders) and all(
1371-
_cmp_command(builders.get((c.type, c.name)), c) for c in registered_commands
1397+
if len(registered_commands) == len(builders_dict) and all(
1398+
_cmp_command(builders_dict.get((c.type, c.name)), c) for c in registered_commands
13721399
):
13731400
_LOGGER.info(
13741401
"Skipping bulk declare for %s application commands since they're already declared", target_type
13751402
)
13761403
return registered_commands
13771404

1378-
_LOGGER.info("Bulk declaring %s %s application commands", len(builders), target_type)
1379-
responses = await self._rest.set_application_commands(application, list(builders.values()), guild=guild)
1405+
_LOGGER.info("Bulk declaring %s %s application commands", len(builders_dict), target_type)
1406+
builders = list(builders_dict.values())
1407+
responses = await self._rest.set_application_commands(application, builders, guild=guild)
13801408

13811409
for response in responses:
13821410
if not guild:
@@ -1390,6 +1418,10 @@ async def declare_application_commands(
13901418
", ".join(f"{response.type}-{response.name}: {response.id}" for response in responses),
13911419
)
13921420

1421+
await self.dispatch_client_callback(
1422+
tanjun.ClientCallbackNames.APP_COMMANDS_DECLARED,
1423+
_DeclaredCommands(builders, responses, None if guild is hikari.UNDEFINED else hikari.Snowflake(guild)),
1424+
)
13931425
return responses
13941426

13951427
def set_auto_defer_after(self: _ClientT, time: typing.Optional[float], /) -> _ClientT:

0 commit comments

Comments
 (0)