Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions tanjun/commands/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@
if typing.TYPE_CHECKING:
from collections import abc as collections

import typing_extensions
from typing_extensions import Self

_AnyCallbackSigT = typing.TypeVar(
"_AnyCallbackSigT", bound=collections.Callable[..., collections.Coroutine[typing.Any, typing.Any, None]]
)
_P = typing_extensions.ParamSpec("_P")
_T = typing.TypeVar("_T")
_CoroT = collections.Coroutine[typing.Any, typing.Any, _T]
_AnyCallbackSigT = typing.TypeVar("_AnyCallbackSigT", bound=collections.Callable[..., _CoroT[None]])
_MessageCallbackSigT = typing.TypeVar("_MessageCallbackSigT", bound=tanjun.MenuCallbackSig[hikari.Message])
_UserCallbackSigT = typing.TypeVar("_UserCallbackSigT", bound=tanjun.MenuCallbackSig[hikari.InteractionMember])

Expand Down Expand Up @@ -509,13 +511,10 @@ def __init__(
self._type: _MenuTypeT = type_ # MyPy bug causes this to need an explicit annotation.
self._wrapped_command = _wrapped_command

if typing.TYPE_CHECKING:
__call__: _AnyMenuCallbackSigT

else:

async def __call__(self, *args, **kwargs) -> None:
await self._callback(*args, **kwargs)
async def __call__(
self: MenuCommand[collections.Callable[_P, _CoroT[None]], _MenuTypeT], *args: _P.args, **kwargs: _P.kwargs
) -> None:
await self._callback(*args, **kwargs)

@property
def callback(self) -> _AnyMenuCallbackSigT:
Expand Down
15 changes: 8 additions & 7 deletions tanjun/commands/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@
from . import base

if typing.TYPE_CHECKING:
import typing_extensions
from typing_extensions import Self

_P = typing_extensions.ParamSpec("_P")
_AnyMessageCommandT = typing.TypeVar("_AnyMessageCommandT", bound=tanjun.MessageCommand[typing.Any])
_AnyCallbackSigT = typing.TypeVar("_AnyCallbackSigT", bound=collections.Callable[..., typing.Any])
_AnyCommandT = typing.Union[
Expand Down Expand Up @@ -241,13 +243,12 @@ def __init__(
def __repr__(self) -> str:
return f"Command <{self._names}>"

if typing.TYPE_CHECKING:
__call__: _MessageCallbackSigT

else:

async def __call__(self, *args, **kwargs) -> None:
await self._callback(*args, **kwargs)
async def __call__(
self: MessageCommand[collections.Callable[_P, collections.Coroutine[typing.Any, typing.Any, None]]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
await self._callback(*args, **kwargs)

@property
def callback(self) -> _MessageCallbackSigT:
Expand Down
17 changes: 8 additions & 9 deletions tanjun/commands/slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from . import base

if typing.TYPE_CHECKING:
import typing_extensions
from hikari.api import special_endpoints as special_endpoints_api
from typing_extensions import Self

Expand All @@ -83,6 +84,8 @@
]
_AnyConverterSig = typing.Union["ConverterSig[float]", "ConverterSig[int]", "ConverterSig[str]"]
_CallbackishT = typing.Union["_SlashCallbackSigT", _AnyCommandT["_SlashCallbackSigT"]]
_T = typing.TypeVar("_T")
_CoroT = collections.Coroutine[typing.Any, typing.Any, _T]

_IntAutocompleteSigT = typing.TypeVar("_IntAutocompleteSigT", bound=tanjun.AutocompleteSig[int])
_FloatAutocompleteSigT = typing.TypeVar("_FloatAutocompleteSigT", bound=tanjun.AutocompleteSig[float])
Expand All @@ -97,8 +100,7 @@
_P = typing_extensions.ParamSpec("_P")

_ConverterSig = collections.Callable[
typing_extensions.Concatenate[_ConvertT, _P],
typing.Union[collections.Coroutine[typing.Any, typing.Any, typing.Any], typing.Any],
typing_extensions.Concatenate[_ConvertT, _P], typing.Union[_CoroT[typing.Any], typing.Any],
]
ConverterSig = _ConverterSig[_ConvertT, ...]
"""Type hint of a slash command option converter.
Expand Down Expand Up @@ -1606,13 +1608,10 @@ def __init__(
self._tracked_options: dict[str, _TrackedOption] = {}
self._wrapped_command = _wrapped_command

if typing.TYPE_CHECKING:
__call__: _SlashCallbackSigT

else:

async def __call__(self, *args, **kwargs) -> None:
await self._callback(*args, **kwargs)
async def __call__(
self: SlashCommand[collections.Callable[_P, _CoroT[None]]], *args: _P.args, **kwargs: _P.kwargs
) -> None:
await self._callback(*args, **kwargs)

@property
def callback(self) -> _SlashCallbackSigT:
Expand Down
28 changes: 13 additions & 15 deletions tanjun/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@
if typing.TYPE_CHECKING:
from alluka import abc as alluka
from typing_extensions import Self
import typing_extensions

from . import abc as tanjun

_OtherCallbackT = typing.TypeVar("_OtherCallbackT", bound="_CallbackSig")
_P = typing_extensions.ParamSpec("_P")
_T = typing.TypeVar("_T")
_CoroT = collections.Coroutine[typing.Any, typing.Any, _T]

_CallbackSig = collections.Callable[..., collections.Coroutine[typing.Any, typing.Any, None]]
_CallbackSig = collections.Callable[..., "_CoroT[None]"]
_CallbackSigT = typing.TypeVar("_CallbackSigT", bound=_CallbackSig)


Expand Down Expand Up @@ -279,13 +283,10 @@ def iteration_count(self) -> int:
# <<inherited docstring from IntervalSchedule>>.
return self._iteration_count

if typing.TYPE_CHECKING:
__call__: _CallbackSigT

else:

async def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
await self._callback(*args, **kwargs)
async def __call__(
self: IntervalSchedule[collections.Callable[_P, _CoroT[None]]], *args: _P.args, **kwargs: _P.kwargs
) -> None:
await self._callback(*args, **kwargs)

def copy(self) -> Self:
# <<inherited docstring from IntervalSchedule>>.
Expand Down Expand Up @@ -1027,13 +1028,10 @@ def is_alive(self) -> bool:
# <<inherited docstring from IntervalSchedule>>.
return self._task is not None

if typing.TYPE_CHECKING:
__call__: _CallbackSigT

else:

async def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
await self._callback(*args, **kwargs)
async def __call__(
self: TimeSchedule[collections.Callable[_P, _CoroT[None]]], *args: _P.args, **kwargs: _P.kwargs
) -> None:
await self._callback(*args, **kwargs)

def copy(self) -> Self:
# <<inherited docstring from IntervalSchedule>>.
Expand Down
16 changes: 13 additions & 3 deletions tests/commands/test_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,22 @@ def test___init___when_command_object(

@pytest.mark.asyncio()
async def test_call_dunder_method(self):
mock_callback: typing.Any = mock.AsyncMock()
mock_ctx = mock.Mock()
mock_message = mock.Mock()
check_called = mock.Mock()

async def mock_callback(ctx: tanjun.abc.Context, message: hikari.Message, other: str, /, *, b: int) -> None:
assert ctx is mock_ctx
assert message is mock_message
assert other == "ea"
assert b == 32
check_called()

command = tanjun.MenuCommand(mock_callback, hikari.CommandType.MESSAGE, "a")

await command(123, 321, "ea", b=32)
await command(mock_ctx, mock_message, "ea", b=32)

mock_callback.assert_awaited_once_with(123, 321, "ea", b=32)
check_called.assert_called_once()

def test_callback_property(self):
mock_callback = mock.Mock()
Expand Down
12 changes: 10 additions & 2 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,13 +1036,21 @@ def test_init_when_float_passed(self, kwargs: dict[str, typing.Any], expected_me

@pytest.mark.asyncio()
async def test_call_dunder_method(self):
mock_callback: typing.Any = mock.AsyncMock()
check_called = mock.Mock()

async def mock_callback(value: int, other: str, /, *, a: int, b: int) -> None:
assert value == 123
assert other == "32"
assert a == 432
assert b == 123
check_called()

interval = tanjun.schedules.TimeSchedule(mock_callback)

result = await interval(123, "32", a=432, b=123)

assert result is None
mock_callback.assert_awaited_once_with(123, "32", a=432, b=123)
check_called.assert_called_once()

def test_copy(self):
mock_callback: typing.Any = mock.AsyncMock()
Expand Down