Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/callosum/rpc/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
Union,
)

import attrs

from ..abc import (
AbstractChannel,
AbstractDeserializer,
Expand Down Expand Up @@ -370,9 +368,9 @@ async def invoke(
server_cancelled = True
raise asyncio.CancelledError
elif response.msgtype == RPCMessageTypes.FAILURE:
raise RPCUserError(*attrs.astuple(response.metadata))
raise RPCUserError.from_err_metadata(response.metadata)
elif response.msgtype == RPCMessageTypes.ERROR:
raise RPCInternalError(*attrs.astuple(response.metadata))
raise RPCInternalError.from_err_metadata(response.metadata)
return upper_result
except (asyncio.TimeoutError, asyncio.CancelledError):
# propagate cancellation to the connected peer
Expand Down
44 changes: 25 additions & 19 deletions src/callosum/rpc/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,47 @@
from ..exceptions import CallosumError
from __future__ import annotations

from typing import TYPE_CHECKING, Self

class RPCError(CallosumError):
"""
A base exception for all RPC-specific errors.
"""
from ..exceptions import CallosumError

pass
if TYPE_CHECKING:
from .message import ErrorMetadata


class RPCUserError(RPCError):
class RPCError(CallosumError):
"""
Represents an error caused in user-defined handlers.
A base exception for all RPC-specific errors.
"""

name: str
repr: str
traceback: str
exceptions: tuple

def __init__(self, name: str, repr_: str, tb: str, *args):
def __init__(self, name: str, repr_: str, tb: str, exceptions: tuple, *args):
super().__init__(name, repr_, tb, *args)
self.name = name
self.repr = repr_
self.traceback = tb
self.exceptions = exceptions

@classmethod
def from_err_metadata(cls, metadata: ErrorMetadata) -> Self:
return cls(
metadata.name,
metadata.repr,
metadata.traceback,
tuple(cls.from_err_metadata(err) for err in metadata.sub_errors),
)

class RPCInternalError(RPCError):

class RPCUserError(RPCError):
"""
Represents an error caused in Calloum's internal logic.
Represents an error caused in user-defined handlers.
"""

name: str
repr: str
traceback: str

def __init__(self, name: str, repr_: str, tb: str, *args):
super().__init__(name, tb, *args)
self.name = name
self.repr = repr_
self.traceback = tb
class RPCInternalError(RPCError):
"""
Represents an error caused in Calloum's internal logic.
"""
61 changes: 51 additions & 10 deletions src/callosum/rpc/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,55 @@ class ErrorMetadata(Metadata):
repr: str
traceback: str

sub_errors: tuple[ErrorMetadata, ...] = attrs.field(factory=tuple)

@classmethod
def decode(cls, buffer: bytes) -> Any:
if not buffer:
return None
values = munpackb(buffer)
match values:
case (name, repr, traceback, raw_sub_errors):
return cls(
name,
repr,
traceback,
tuple(cls.decode(raw_error) for raw_error in raw_sub_errors),
)
case _:
return cls(*values)

def encode(self) -> bytes:
values = [
self.name,
self.repr,
self.traceback,
[err.encode() for err in self.sub_errors],
]
return mpackb(values)

@classmethod
def from_exception(
cls, exc: BaseExceptionGroup | BaseException, formatted_traceback: str
) -> ErrorMetadata:
match exc:
case BaseExceptionGroup():
return ErrorMetadata(
"ExceptionGroup",
repr(exc),
formatted_traceback,
sub_errors=tuple(
cls.from_exception(sub_exc, formatted_traceback)
for sub_exc in exc.exceptions
),
)
case _:
return ErrorMetadata(
type(exc).__name__,
repr(exc),
formatted_traceback,
)


@attrs.define(frozen=True, slots=True)
class NullMetadata(Metadata):
Expand Down Expand Up @@ -150,11 +199,7 @@ def failure(cls, request):
request.method,
request.order_key,
request.client_seq_id,
ErrorMetadata(
exc_info[0].__name__,
repr(exc_info[1]),
traceback.format_exc(),
),
ErrorMetadata.from_exception(exc_info[1], traceback.format_exc()),
None,
)

Expand All @@ -174,11 +219,7 @@ def error(cls, request):
request.method,
request.order_key,
request.client_seq_id,
ErrorMetadata(
exc_info[0].__name__,
repr(exc_info[1]),
traceback.format_exc(),
),
ErrorMetadata.from_exception(exc_info[1], traceback.format_exc()),
None,
)

Expand Down