diff --git a/tanjun/components.py b/tanjun/components.py index ad2e91717..b1c905689 100644 --- a/tanjun/components.py +++ b/tanjun/components.py @@ -132,6 +132,10 @@ class Component(abc.Component): ---------- checks : typing.Optional[collections.abc.Iterable[abc.CheckSig]] Iterable of check callbacks to set for this component, if provided. + slash_checks : typing.Optional[collections.abc.Iterable[tanjun.abc.CheckSig]] + The slash check callbacks to set for the component, if provided. + message_checks : typing.Optional[collections.abc.Iterable[tanjun.abc.CheckSig]] + The message check callbacks to set for the component, if provided. hooks : typing.Optional[tanjun.abc.AnyHooks] The hooks this component should add to the execution of all its commands (message and slash). @@ -162,11 +166,13 @@ class Component(abc.Component): "_is_strict", "_listeners", "_message_commands", + "_message_checks", "_message_hooks", "_metadata", "_name", "_names_to_commands", "_slash_commands", + "_slash_checks", "_slash_hooks", ) @@ -174,6 +180,8 @@ def __init__( self, *, checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None, + slash_checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None, + message_checks: typing.Optional[collections.Iterable[abc.CheckSig]] = None, hooks: typing.Optional[abc.AnyHooks] = None, slash_hooks: typing.Optional[abc.SlashHooks] = None, message_hooks: typing.Optional[abc.MessageHooks] = None, @@ -191,11 +199,17 @@ def __init__( self._is_strict = strict self._listeners: dict[type[base_events.Event], list[abc.ListenerCallbackSig]] = {} self._message_commands: list[abc.MessageCommand] = [] + self._message_checks: list[checks_.InjectableCheck] = ( + [checks_.InjectableCheck(check) for check in dict.fromkeys(slash_checks)] if slash_checks else [] + ) self._message_hooks = message_hooks self._metadata: dict[typing.Any, typing.Any] = {} self._name = name or base64.b64encode(random.randbytes(32)).decode() self._names_to_commands: dict[str, abc.MessageCommand] = {} self._slash_commands: dict[str, abc.BaseSlashCommand] = {} + self._slash_checks: list[checks_.InjectableCheck] = ( + [checks_.InjectableCheck(check) for check in dict.fromkeys(message_checks)] if message_checks else [] + ) self._slash_hooks = slash_hooks if load_from_attributes and type(self) is not Component: # No need to run this on the base class. @@ -228,6 +242,10 @@ def name(self) -> str: def slash_commands(self) -> collections.Collection[abc.BaseSlashCommand]: return self._slash_commands.copy().values() + @property + def slash_checks(self) -> collections.Collection[abc.CheckSig]: + return tuple(check.callback for check in self._slash_checks) + @property def slash_hooks(self) -> typing.Optional[abc.SlashHooks]: return self._slash_hooks @@ -236,13 +254,17 @@ def slash_hooks(self) -> typing.Optional[abc.SlashHooks]: def message_commands(self) -> collections.Collection[abc.MessageCommand]: return self._message_commands.copy() + @property + def message_checks(self) -> collections.Collection[abc.CheckSig]: + return tuple(check.callback for check in self._message_checks) + @property def message_hooks(self) -> typing.Optional[abc.MessageHooks]: return self._message_hooks @property def needs_injector(self) -> bool: - return any(check.needs_injector for check in self._checks) + return any(check.needs_injector for check in self._checks + self._message_checks + self._slash_checks) @property def listeners( @@ -258,12 +280,14 @@ def copy(self: _ComponentT, *, _new: bool = True) -> _ComponentT: if not _new: self._checks = [check.copy() for check in self._checks] self._slash_commands = {name: command.copy() for name, command in self._slash_commands.items()} + self._slash_checks = [check.copy() for check in self._slash_checks] self._hooks = self._hooks.copy() if self._hooks else None self._listeners = { event: [copy.copy(listener) for listener in listeners] for event, listeners in self._listeners.items() } commands = {command: command.copy() for command in self._message_commands} self._message_commands = list(commands.values()) + self._message_checks = [check.copy() for check in self._message_checks] self._metadata = self._metadata.copy() self._names_to_commands = {name: commands[command] for name, command in self._names_to_commands.items()} return self @@ -317,6 +341,34 @@ def with_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT: self.add_check(check) return check + def add_slash_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: + if check not in self._slash_checks: + self._slash_checks.append(checks_.InjectableCheck(check)) + + return self + + def remove_slash_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: + self._slash_checks.remove(typing.cast("checks_.InjectableCheck", check)) + return self + + def with_slash_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT: + self.add_slash_check(check) + return check + + def add_message_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: + if check not in self._message_checks: + self._message_checks.append(checks_.InjectableCheck(check)) + + return self + + def remove_message_check(self: _ComponentT, check: abc.CheckSig, /) -> _ComponentT: + self._message_checks.remove(typing.cast("checks_.InjectableCheck", check)) + return self + + def with_message_check(self, check: abc.CheckSigT, /) -> abc.CheckSigT: + self.add_message_check(check) + return check + def add_client_callback(self: _ComponentT, event_name: str, callback: abc.MetaEventSig, /) -> _ComponentT: event_name = event_name.lower() try: @@ -576,7 +628,11 @@ def unbind_client(self, client: abc.Client, /) -> None: self._client = None async def _check_context(self, ctx: abc.Context, /) -> bool: - return await utilities.gather_checks(ctx, self._checks) + if abc.SlashContext in type(ctx).mro(): + additional_checks = self._slash_checks + else: + additional_checks = self._message_checks + return await utilities.gather_checks(ctx, self._checks + additional_checks) async def _check_message_context( self, ctx: abc.MessageContext, / diff --git a/tests/test_components.py b/tests/test_components.py index f68b1f0d2..732935a9c 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -160,6 +160,86 @@ def test_with_check(self): assert result is mock_check + def test_add_slash_check(self): + mock_check = mock.Mock() + component = tanjun.Component() + + result = component.add_slash_check(mock_check) + + assert result is component + + def test_add_slash_check_when_already_present(self): + mock_check = mock.Mock() + component = tanjun.Component().add_slash_check(mock_check) + + with mock.patch.object(tanjun.checks, "InjectableCheck") as InjectableCheck: + result = component.add_slash_check(mock_check) + + InjectableCheck.assert_not_called() + + assert list(component.slash_checks).count(mock_check) == 1 + assert result is component + + def test_remove_slash_check(self): + component = tanjun.Component().add_slash_check(mock.Mock()) + + result = component.remove_slash_check(next(iter(component.slash_checks))) + + assert result is component + assert not component.slash_checks + + def test_remove_slash_check_when_not_present(self): + with pytest.raises(ValueError, match=".+"): + tanjun.Component().remove_slash_check(mock.Mock()) + + def test_with_slash_check(self): + mock_check = mock.Mock() + component = tanjun.Component() + + result = component.with_slash_check(mock_check) + + assert result is mock_check + + def test_add_message_check(self): + mock_check = mock.Mock() + component = tanjun.Component() + + result = component.add_message_check(mock_check) + + assert result is component + + def test_add_message_check_when_already_present(self): + mock_check = mock.Mock() + component = tanjun.Component().add_message_check(mock_check) + + with mock.patch.object(tanjun.checks, "InjectableCheck") as InjectableCheck: + result = component.add_message_check(mock_check) + + InjectableCheck.assert_not_called() + + assert list(component.message_checks).count(mock_check) == 1 + assert result is component + + def test_remove_message_check(self): + component = tanjun.Component().add_message_check(mock.Mock()) + + result = component.remove_message_check(next(iter(component.message_checks))) + + assert result is component + assert not component.message_checks + + def test_remove_message_check_when_not_present(self): + with pytest.raises(ValueError, match=".+"): + tanjun.Component().remove_message_check(mock.Mock()) + + def test_with_message_check(self): + mock_check = mock.Mock() + component = tanjun.Component() + + result = component.with_message_check(mock_check) + + assert result is mock_check + def test_add_client_callback(self): mock_callback = mock.Mock() mock_other_callback = mock.Mock()