Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tanjun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ async def main() -> None:
"with_owner_check",
"with_author_permission_check",
"with_own_permission_check",
"with_any_role_check",
# clients.py
"clients",
"as_loader",
Expand Down
74 changes: 74 additions & 0 deletions tanjun/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@
"with_owner_check",
"with_author_permission_check",
"with_own_permission_check",
"with_any_role_check",
"DmCheck",
"GuildCheck",
"NsfwCheck",
"SfwCheck",
"OwnerCheck",
"AuthorPermissionCheck",
"OwnPermissionCheck",
"HasAnyRoleCheck",
]

import typing
Expand Down Expand Up @@ -509,6 +511,43 @@ async def __call__(
return self._handle_result((permissions & self._permissions) == self._permissions)


class HasAnyRoleCheck(_Check):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One style thing, since this was originally written doc strings have been added to the check classes and their inits so that would prob be added for this as well

__slots__ = ("required_roles", "ids_only")

def __init__(
self,
roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], str]] = [],
*,
error_message: typing.Optional[str] = "You do not have the required roles to use this command!",
halt_execution: bool = True,
) -> None:
super().__init__(error_message, halt_execution)
self.required_roles = roles
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another style thing, these attributes should be private now (so _required_roles and _ids_only

self.ids_only = all(isinstance(role, int) for role in self.required_roles)

async def __call__(self, ctx: tanjun_abc.Context, /) -> bool:
if not ctx.member:
return self._handle_result(False)

if not self.ids_only:
guild_roles = ctx.cache.get_roles_view_for_guild(ctx.member.guild_id) if ctx.cache else None
if not guild_roles:
guild_roles = await ctx.rest.fetch_roles(ctx.member.guild_id)
member_roles = [role for role in guild_roles if role.id in ctx.member.role_ids]
else:
member_roles = [guild_roles.get(role) for role in ctx.member.role_ids]
else:
member_roles = ctx.member.role_ids

return self._handle_result(any(map(self._check_roles, member_roles)))

def _check_roles(self, member_role: typing.Union[int, hikari.Role]) -> bool:
if isinstance(member_role, int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than instance check here could the different checks be performed in the separate parts of the if not self._ids_only else statement to avoid the instance checks all together (if type checking doesn't quite like this you can just cast but i don't think it should be too strict for equality checks)

return any(member_role == check for check in self.required_roles)

return any(member_role.id == check or member_role.name == check for check in self.required_roles)


@typing.overload
def with_dm_check(command: CommandT, /) -> CommandT:
...
Expand Down Expand Up @@ -862,6 +901,41 @@ def with_own_permission_check(
)


def with_any_role_check(
roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], int, str]] = [],
*,
error_message: typing.Optional[str] = "You do not have the required roles to use this command!",
halt_execution: bool = False,
) -> collections.Callable[[CommandT], CommandT]:
"""Only let a command run if the author has a specific role and the command is called in a guild.

Parameters
----------
roles: collections.Sequence[Union[SnowflakeishOr[Role], int, str]]
The author must have at least one (1) role in this list. (Role.name and Role.id are checked)

Other Parameters
----------------
error_message: Optional[str]
The error message raised if the member does not have a required role.

Defaults to 'You do not have the required roles to use this command!'
halt_execution: bool
Whether this check should raise `tanjun.errors.HaltExecution` to
end the execution search when it fails instead of returning `False`.

Defaults to `False`.

Returns
-------
collections.abc.Callable[[CommandT], CommandT]
A command decorator callback which adds the check.
"""
return lambda command: command.add_check(
HasAnyRoleCheck(roles, error_message=error_message, halt_execution=halt_execution)
)


def with_check(check: tanjun_abc.CheckSig, /) -> collections.Callable[[CommandT], CommandT]:
"""Add a generic check to a command.

Expand Down
23 changes: 23 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,29 @@ def test_with_own_permission_check(command: mock.Mock):
own_permission_check.assert_called_once_with(5412312, halt_execution=True, error_message="hi")


def test_with_has_any_role_check(command: mock.Mock):
with mock.patch.object(tanjun.checks, "HasAnyRoleCheck") as any_role_check:
assert (
tanjun.checks.with_any_role_check(
[
"Admin",
],
halt_execution=True,
error_message="hi",
)(command)
is command
)

command.add_check.assert_called_once_with(any_role_check.return_value)
any_role_check.assert_called_once_with(
[
"Admin",
],
halt_execution=True,
error_message="hi",
)


def test_with_check(command: mock.Mock):
mock_check = mock.Mock()

Expand Down