diff --git a/packages/faststream-stomp/faststream_stomp/__init__.py b/packages/faststream-stomp/faststream_stomp/__init__.py index 3d69243..800786d 100644 --- a/packages/faststream-stomp/faststream_stomp/__init__.py +++ b/packages/faststream-stomp/faststream_stomp/__init__.py @@ -1,5 +1,5 @@ from faststream_stomp.broker import StompBroker -from faststream_stomp.message import StompStreamMessage +from faststream_stomp.models import StompPublishCommand, StompStreamMessage from faststream_stomp.publisher import StompPublisher from faststream_stomp.router import StompRoute, StompRoutePublisher, StompRouter from faststream_stomp.subscriber import StompSubscriber @@ -7,6 +7,7 @@ __all__ = [ "StompBroker", + "StompPublishCommand", "StompPublisher", "StompRoute", "StompRoutePublisher", diff --git a/packages/faststream-stomp/faststream_stomp/broker.py b/packages/faststream-stomp/faststream_stomp/broker.py index 6521a75..f3567c1 100644 --- a/packages/faststream-stomp/faststream_stomp/broker.py +++ b/packages/faststream-stomp/faststream_stomp/broker.py @@ -1,22 +1,31 @@ import asyncio import logging import types -from collections.abc import Callable, Iterable, Mapping, Sequence -from typing import Any, Unpack +import typing +from collections.abc import Iterable, Sequence +from typing import Any import anyio import stompman -from fast_depends.dependencies import Depends -from faststream.asyncapi.schema import Tag, TagDict -from faststream.broker.core.usecase import BrokerUsecase -from faststream.broker.types import BrokerMiddleware, CustomCallable -from faststream.log.logging import get_broker_logger +from fast_depends.dependencies import Dependant +from faststream import ContextRepo, PublishType +from faststream._internal.basic_types import AnyDict, LoggerProto, SendableMessage +from faststream._internal.broker import BrokerUsecase +from faststream._internal.broker.registrator import Registrator +from faststream._internal.configs import BrokerConfig +from faststream._internal.constants import EMPTY +from faststream._internal.di import FastDependsConfig +from faststream._internal.logger import DefaultLoggerStorage, make_logger_state +from faststream._internal.logger.logging import get_broker_logger +from faststream._internal.types import BrokerMiddleware, CustomCallable from faststream.security import BaseSecurity -from faststream.types import EMPTY, AnyDict, Decorator, LoggerProto, SendableMessage +from faststream.specification.schema import BrokerSpec +from faststream.specification.schema.extra import Tag, TagDict +from faststream_stomp.models import BrokerConfigWithStompClient, StompPublishCommand from faststream_stomp.publisher import StompProducer, StompPublisher from faststream_stomp.registrator import StompRegistrator -from faststream_stomp.subscriber import StompLogContext, StompSubscriber +from faststream_stomp.subscriber import StompSubscriber class StompSecurity(BaseSecurity): @@ -40,83 +49,111 @@ def _handle_listen_task_done(listen_task: asyncio.Task[None]) -> None: raise SystemExit(1) -class StompBroker(StompRegistrator, BrokerUsecase[stompman.MessageFrame, stompman.Client]): - _subscribers: Mapping[int, StompSubscriber] - _publishers: Mapping[int, StompPublisher] +class StompParamsStorage(DefaultLoggerStorage): __max_msg_id_ln = 10 _max_channel_name = 4 + def get_logger(self, *, context: ContextRepo) -> LoggerProto: + if logger := self._get_logger_ref(): + return logger + logger = get_broker_logger( + name="stomp", + default_context={"destination": "", "message_id": ""}, + message_id_ln=self.__max_msg_id_ln, + fmt=( + "%(asctime)s %(levelname)-8s - " + f"%(destination)-{self._max_channel_name}s | " + f"%(message_id)-{self.__max_msg_id_ln}s " + "- %(message)s" + ), + context=context, + log_level=self.logger_log_level, + ) + self._logger_ref.add(logger) + return logger + + +class StompBroker( + StompRegistrator, + BrokerUsecase[ + stompman.MessageFrame, + stompman.Client, + BrokerConfig, # Using BrokerConfig to avoid typing issues when passing broker to FastStream app + ], +): + _subscribers: list[StompSubscriber] # type: ignore[assignment] + _publishers: list[StompPublisher] # type: ignore[assignment] + def __init__( self, client: stompman.Client, *, decoder: CustomCallable | None = None, parser: CustomCallable | None = None, - dependencies: Iterable[Depends] = (), - middlewares: Sequence[BrokerMiddleware[stompman.MessageFrame]] = (), + dependencies: Iterable[Dependant] = (), + middlewares: Sequence[BrokerMiddleware[stompman.MessageFrame, StompPublishCommand]] = (), graceful_timeout: float | None = 15.0, + routers: Sequence[Registrator[stompman.MessageFrame]] = (), # Logging args logger: LoggerProto | None = EMPTY, log_level: int = logging.INFO, # FastDepends args apply_types: bool = True, - validate: bool = True, - _get_dependant: Callable[..., Any] | None = None, - _call_decorators: Iterable[Decorator] = (), - # AsyncAPI kwargs, + # AsyncAPI args description: str | None = None, - tags: Iterable[Tag | TagDict] | None = None, + tags: Iterable[Tag | TagDict] = (), ) -> None: - super().__init__( - client=client, # **connection_kwargs - decoder=decoder, - parser=parser, - dependencies=dependencies, - middlewares=middlewares, + broker_config = BrokerConfigWithStompClient( + broker_middlewares=middlewares, # type: ignore[arg-type] + broker_parser=parser, + broker_decoder=decoder, + logger=make_logger_state( + logger=logger, + log_level=log_level, + default_storage_cls=StompParamsStorage, # type: ignore[type-abstract] + ), + fd_config=FastDependsConfig(use_fastdepends=apply_types), + broker_dependencies=dependencies, graceful_timeout=graceful_timeout, - logger=logger, - log_level=log_level, - apply_types=apply_types, - validate=validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + extra_context={"broker": self}, + producer=StompProducer(client), + client=client, + ) + specification = BrokerSpec( + url=[f"{one_server.host}:{one_server.port}" for one_server in broker_config.client.servers], protocol="STOMP", protocol_version="1.2", description=description, tags=tags, - asyncapi_url=[f"{one_server.host}:{one_server.port}" for one_server in client.servers], security=StompSecurity(), - default_logger=get_broker_logger( - name="stomp", default_context={"channel": ""}, message_id_ln=self.__max_msg_id_ln - ), ) - self._attempted_to_connect = False - - async def start(self) -> None: - await super().start() - for handler in self._subscribers.values(): - self._log(f"`{handler.call_name}` waiting for messages", extra=handler.get_log_context(None)) - await handler.start() + super().__init__(config=broker_config, specification=specification, routers=routers) + self._attempted_to_connect = False - async def _connect(self, client: stompman.Client) -> stompman.Client: # type: ignore[override] + async def _connect(self) -> stompman.Client: if self._attempted_to_connect: - return client + return self.config.broker_config.client self._attempted_to_connect = True - self._producer = StompProducer(client) - await client.__aenter__() - client._listen_task.add_done_callback(_handle_listen_task_done) # noqa: SLF001 - return client + await self.config.broker_config.client.__aenter__() + self.config.broker_config.client._listen_task.add_done_callback(_handle_listen_task_done) + return self.config.broker_config.client + + async def start(self) -> None: + await self.connect() + await super().start() - async def _close( + async def stop( self, exc_type: type[BaseException] | None = None, exc_val: BaseException | None = None, exc_tb: types.TracebackType | None = None, ) -> None: + for sub in self.subscribers: + await sub.stop() if self._connection: await self._connection.__aexit__(exc_type, exc_val, exc_tb) - return await super()._close(exc_type, exc_val, exc_tb) + self.running = False async def ping(self, timeout: float | None = None) -> bool: sleep_time = (timeout or 10) / 10 @@ -135,22 +172,7 @@ async def ping(self, timeout: float | None = None) -> bool: return False # pragma: no cover - def get_fmt(self) -> str: - # `StompLogContext` - return ( - "%(asctime)s %(levelname)-8s - " - f"%(destination)-{self._max_channel_name}s | " - f"%(message_id)-{self.__max_msg_id_ln}s " - "- %(message)s" - ) - - def _setup_log_context(self, **log_context: Unpack[StompLogContext]) -> None: ... # type: ignore[override] - - @property - def _subscriber_setup_extra(self) -> "AnyDict": - return {**super()._subscriber_setup_extra, "client": self._connection} - - async def publish( # type: ignore[override] + async def publish( self, message: SendableMessage, destination: str, @@ -158,19 +180,44 @@ async def publish( # type: ignore[override] correlation_id: str | None = None, headers: dict[str, str] | None = None, ) -> None: - await super().publish( + publish_command = StompPublishCommand( message, - producer=self._producer, - correlation_id=correlation_id, + _publish_type=PublishType.PUBLISH, destination=destination, + correlation_id=correlation_id, headers=headers, ) + return typing.cast("None", await self._basic_publish(publish_command, producer=self.config.producer)) async def request( # type: ignore[override] self, - msg: Any, # noqa: ANN401 + message: SendableMessage, + destination: str, *, correlation_id: str | None = None, headers: dict[str, str] | None = None, ) -> Any: # noqa: ANN401 - return await super().request(msg, producer=self._producer, correlation_id=correlation_id, headers=headers) + publish_command = StompPublishCommand( + message, + _publish_type=PublishType.REQUEST, + destination=destination, + correlation_id=correlation_id, + headers=headers, + ) + return await self._basic_request(publish_command, producer=self.config.producer) + + async def publish_batch( # type: ignore[override] + self, + *_messages: SendableMessage, + destination: str, + correlation_id: str | None = None, + headers: dict[str, str] | None = None, + ) -> None: + publish_command = StompPublishCommand( + "", + _publish_type=PublishType.PUBLISH, + destination=destination, + correlation_id=correlation_id, + headers=headers, + ) + return typing.cast("None", await self._basic_publish_batch(publish_command, producer=self.config.producer)) diff --git a/packages/faststream-stomp/faststream_stomp/message.py b/packages/faststream-stomp/faststream_stomp/message.py deleted file mode 100644 index 5452e31..0000000 --- a/packages/faststream-stomp/faststream_stomp/message.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Self, cast - -import stompman -from faststream.broker.message import StreamMessage, gen_cor_id - - -class StompStreamMessage(StreamMessage[stompman.AckableMessageFrame]): - async def ack(self) -> None: - if not self.committed: - await self.raw_message.ack() - return await super().ack() - - async def nack(self) -> None: - if not self.committed: - await self.raw_message.nack() - return await super().nack() - - async def reject(self) -> None: - if not self.committed: - await self.raw_message.nack() - return await super().reject() - - @classmethod - async def from_frame(cls, message: stompman.AckableMessageFrame) -> Self: - return cls( - raw_message=message, - body=message.body, - headers=cast("dict[str, str]", message.headers), - content_type=message.headers.get("content-type"), - message_id=message.headers["message-id"], - correlation_id=cast("str", message.headers.get("correlation-id", gen_cor_id())), - ) diff --git a/packages/faststream-stomp/faststream_stomp/models.py b/packages/faststream-stomp/faststream_stomp/models.py new file mode 100644 index 0000000..8d8051e --- /dev/null +++ b/packages/faststream-stomp/faststream_stomp/models.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass, field +from typing import Self, cast + +import stompman +from faststream import AckPolicy, PublishCommand, StreamMessage +from faststream._internal.configs import ( + BrokerConfig, + PublisherSpecificationConfig, + PublisherUsecaseConfig, + SubscriberSpecificationConfig, + SubscriberUsecaseConfig, +) +from faststream._internal.types import AsyncCallable +from faststream._internal.utils.functions import to_async +from faststream.message import decode_message, gen_cor_id + + +class StompStreamMessage(StreamMessage[stompman.AckableMessageFrame]): + async def ack(self) -> None: + if not self.committed: + await self.raw_message.ack() + return await super().ack() + + async def nack(self) -> None: + if not self.committed: + await self.raw_message.nack() + return await super().nack() + + async def reject(self) -> None: + if not self.committed: + await self.raw_message.nack() + return await super().reject() + + @classmethod + async def from_frame(cls, message: stompman.AckableMessageFrame) -> Self: + return cls( + raw_message=message, + body=message.body, + headers=cast("dict[str, str]", message.headers), + content_type=message.headers.get("content-type"), + message_id=message.headers["message-id"], + correlation_id=cast("str", message.headers.get("correlation-id", gen_cor_id())), + ) + + +class StompPublishCommand(PublishCommand): + @classmethod + def from_cmd(cls, cmd: PublishCommand) -> Self: + return cmd # type: ignore[return-value] + + +@dataclass(kw_only=True) +class BrokerConfigWithStompClient(BrokerConfig): + client: stompman.Client + + +@dataclass(kw_only=True) +class _StompBaseSubscriberConfig: + destination_without_prefix: str + ack_mode: stompman.AckMode + headers: dict[str, str] | None + + +@dataclass(kw_only=True) +class StompSubscriberSpecificationConfig(_StompBaseSubscriberConfig, SubscriberSpecificationConfig): + parser: AsyncCallable = StompStreamMessage.from_frame + decoder: AsyncCallable = field(default=to_async(decode_message)) + + +@dataclass(kw_only=True) +class StompSubscriberUsecaseConfig(_StompBaseSubscriberConfig, SubscriberUsecaseConfig): + _outer_config: BrokerConfigWithStompClient + parser: AsyncCallable = StompStreamMessage.from_frame + decoder: AsyncCallable = field(default=to_async(decode_message)) + + @property + def ack_policy(self) -> AckPolicy: + return AckPolicy.MANUAL if self.ack_mode == "auto" else AckPolicy.NACK_ON_ERROR + + @property + def full_destination(self) -> str: + return self._outer_config.prefix + self.destination_without_prefix + + +@dataclass(kw_only=True) +class _StompBasePublisherConfig: + destination_without_prefix: str + + +@dataclass(kw_only=True) +class StompPublisherSpecificationConfig(_StompBasePublisherConfig, PublisherSpecificationConfig): ... + + +@dataclass(kw_only=True) +class StompPublisherUsecaseConfig(_StompBasePublisherConfig, PublisherUsecaseConfig): + _outer_config: BrokerConfigWithStompClient + + @property + def full_destination(self) -> str: + return self._outer_config.prefix + self.destination_without_prefix diff --git a/packages/faststream-stomp/faststream_stomp/opentelemetry.py b/packages/faststream-stomp/faststream_stomp/opentelemetry.py index db34e5d..86a1db4 100644 --- a/packages/faststream-stomp/faststream_stomp/opentelemetry.py +++ b/packages/faststream-stomp/faststream_stomp/opentelemetry.py @@ -1,21 +1,23 @@ import stompman -from faststream.broker.message import StreamMessage +from faststream import StreamMessage +from faststream._internal.basic_types import AnyDict from faststream.opentelemetry import TelemetrySettingsProvider from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME from faststream.opentelemetry.middleware import TelemetryMiddleware -from faststream.types import AnyDict from opentelemetry.metrics import Meter, MeterProvider -from opentelemetry.semconv._incubating.attributes import messaging_attributes # noqa: PLC2701 +from opentelemetry.semconv._incubating.attributes import messaging_attributes from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import TracerProvider -from faststream_stomp.publisher import StompProducerPublishKwargs +from faststream_stomp.models import StompPublishCommand +__all__ = ["StompTelemetryMiddleware", "StompTelemetrySettingsProvider"] -class StompTelemetrySettingsProvider(TelemetrySettingsProvider[stompman.MessageFrame]): + +class StompTelemetrySettingsProvider(TelemetrySettingsProvider[stompman.MessageFrame, StompPublishCommand]): messaging_system = "stomp" - def get_consume_attrs_from_message(self, msg: StreamMessage[stompman.MessageFrame]) -> "AnyDict": + def get_consume_attrs_from_message(self, msg: StreamMessage[stompman.MessageFrame]) -> AnyDict: return { messaging_attributes.MESSAGING_SYSTEM: self.messaging_system, messaging_attributes.MESSAGING_MESSAGE_ID: msg.message_id, @@ -27,17 +29,17 @@ def get_consume_attrs_from_message(self, msg: StreamMessage[stompman.MessageFram def get_consume_destination_name(self, msg: StreamMessage[stompman.MessageFrame]) -> str: # noqa: PLR6301 return msg.raw_message.headers["destination"] - def get_publish_attrs_from_kwargs(self, kwargs: StompProducerPublishKwargs) -> AnyDict: # type: ignore[override] + def get_publish_attrs_from_cmd(self, cmd: StompPublishCommand) -> AnyDict: publish_attrs = { messaging_attributes.MESSAGING_SYSTEM: self.messaging_system, - messaging_attributes.MESSAGING_DESTINATION_NAME: kwargs["destination"], + messaging_attributes.MESSAGING_DESTINATION_NAME: cmd.destination, } - if kwargs["correlation_id"]: - publish_attrs[messaging_attributes.MESSAGING_MESSAGE_CONVERSATION_ID] = kwargs["correlation_id"] + if cmd.correlation_id: + publish_attrs[messaging_attributes.MESSAGING_MESSAGE_CONVERSATION_ID] = cmd.correlation_id return publish_attrs - def get_publish_destination_name(self, kwargs: StompProducerPublishKwargs) -> str: # type: ignore[override] # noqa: PLR6301 - return kwargs["destination"] + def get_publish_destination_name(self, cmd: StompPublishCommand) -> str: # noqa: PLR6301 + return cmd.destination class StompTelemetryMiddleware(TelemetryMiddleware): @@ -49,7 +51,7 @@ def __init__( meter: Meter | None = None, ) -> None: super().__init__( - settings_provider_factory=lambda _: StompTelemetrySettingsProvider(), + settings_provider_factory=lambda _: StompTelemetrySettingsProvider(), # type: ignore[arg-type,return-value] tracer_provider=tracer_provider, meter_provider=meter_provider, meter=meter, diff --git a/packages/faststream-stomp/faststream_stomp/prometheus.py b/packages/faststream-stomp/faststream_stomp/prometheus.py index 36606c5..aa3c7ab 100644 --- a/packages/faststream-stomp/faststream_stomp/prometheus.py +++ b/packages/faststream-stomp/faststream_stomp/prometheus.py @@ -3,22 +3,22 @@ from typing import TYPE_CHECKING import stompman +from faststream._internal.constants import EMPTY from faststream.prometheus import ConsumeAttrs, MetricsSettingsProvider -from faststream.prometheus.middleware import BasePrometheusMiddleware -from faststream.types import EMPTY +from faststream.prometheus.middleware import PrometheusMiddleware + +from faststream_stomp.models import StompPublishCommand if TYPE_CHECKING: from collections.abc import Sequence - from faststream.broker.message import StreamMessage + from faststream import StreamMessage from prometheus_client import CollectorRegistry - from faststream_stomp.publisher import StompProducerPublishKwargs - __all__ = ["StompMetricsSettingsProvider", "StompPrometheusMiddleware"] -class StompMetricsSettingsProvider(MetricsSettingsProvider[stompman.MessageFrame]): +class StompMetricsSettingsProvider(MetricsSettingsProvider[stompman.MessageFrame, StompPublishCommand]): messaging_system = "stomp" def get_consume_attrs_from_message(self, msg: StreamMessage[stompman.MessageFrame]) -> ConsumeAttrs: # noqa: PLR6301 @@ -28,11 +28,11 @@ def get_consume_attrs_from_message(self, msg: StreamMessage[stompman.MessageFram "messages_count": 1, } - def get_publish_destination_name_from_kwargs(self, kwargs: StompProducerPublishKwargs) -> str: # type: ignore[override] # noqa: PLR6301 - return kwargs["destination"] + def get_publish_destination_name_from_cmd(self, cmd: StompPublishCommand) -> str: # noqa: PLR6301 + return cmd.destination -class StompPrometheusMiddleware(BasePrometheusMiddleware): +class StompPrometheusMiddleware(PrometheusMiddleware[StompPublishCommand, stompman.MessageFrame]): def __init__( self, *, diff --git a/packages/faststream-stomp/faststream_stomp/publisher.py b/packages/faststream-stomp/faststream_stomp/publisher.py index 2f3c570..7c8115f 100644 --- a/packages/faststream-stomp/faststream_stomp/publisher.py +++ b/packages/faststream-stomp/faststream_stomp/publisher.py @@ -1,121 +1,108 @@ -from collections.abc import Sequence -from functools import partial -from itertools import chain -from typing import Any, TypedDict, Unpack +import typing +from typing import Any, NoReturn import stompman -from faststream.asyncapi.schema import Channel, CorrelationId, Message, Operation -from faststream.asyncapi.utils import resolve_payloads -from faststream.broker.message import encode_message -from faststream.broker.publisher.proto import ProducerProto -from faststream.broker.publisher.usecase import PublisherUsecase -from faststream.broker.types import AsyncCallable, BrokerMiddleware, PublisherMiddleware -from faststream.exceptions import NOT_CONNECTED_YET -from faststream.types import SendableMessage - - -class StompProducerPublishKwargs(TypedDict): - destination: str - correlation_id: str | None - headers: dict[str, str] | None - - -class StompProducer(ProducerProto): +from faststream import PublishCommand, PublishType +from faststream._internal.basic_types import SendableMessage +from faststream._internal.configs import BrokerConfig +from faststream._internal.endpoint.publisher import PublisherSpecification, PublisherUsecase +from faststream._internal.producer import ProducerProto +from faststream._internal.types import AsyncCallable, PublisherMiddleware +from faststream.message import encode_message +from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, PublisherSpec + +from faststream_stomp.models import ( + StompPublishCommand, + StompPublisherSpecificationConfig, + StompPublisherUsecaseConfig, +) + + +class StompProducer(ProducerProto[StompPublishCommand]): _parser: AsyncCallable _decoder: AsyncCallable def __init__(self, client: stompman.Client) -> None: self.client = client - async def publish(self, message: SendableMessage, **kwargs: Unpack[StompProducerPublishKwargs]) -> None: # type: ignore[override] - body, content_type = encode_message(message) - all_headers = kwargs["headers"].copy() if kwargs["headers"] else {} - if kwargs["correlation_id"]: - all_headers["correlation-id"] = kwargs["correlation_id"] - await self.client.send(body, kwargs["destination"], content_type=content_type, headers=all_headers) + async def publish(self, cmd: StompPublishCommand) -> None: + body, content_type = encode_message(cmd.body, serializer=None) + all_headers = cmd.headers.copy() if cmd.headers else {} + if cmd.correlation_id: + all_headers["correlation-id"] = cmd.correlation_id + await self.client.send(body, cmd.destination, content_type=content_type, headers=all_headers) - async def request( # type: ignore[override] - self, message: SendableMessage, *, correlation_id: str | None, headers: dict[str, str] | None - ) -> Any: # noqa: ANN401 + async def request(self, cmd: StompPublishCommand) -> NoReturn: msg = "`StompProducer` can be used only to publish a response for `reply-to` or `RPC` messages." raise NotImplementedError(msg) + async def publish_batch(self, cmd: StompPublishCommand) -> NoReturn: + raise NotImplementedError -class StompPublisher(PublisherUsecase[stompman.MessageFrame]): - _producer: StompProducer | None - def __init__( - self, - destination: str, - *, - broker_middlewares: Sequence[BrokerMiddleware[stompman.MessageFrame]], - middlewares: Sequence[PublisherMiddleware], - schema_: Any | None, # noqa: ANN401 - title_: str | None, - description_: str | None, - include_in_schema: bool, +class StompPublisherSpecification(PublisherSpecification[BrokerConfig, StompPublisherSpecificationConfig]): + @property + def name(self) -> str: + return f"{self._outer_config.prefix}{self.config.destination_without_prefix}:Publisher" + + def get_schema(self) -> dict[str, PublisherSpec]: + return { + self.name: PublisherSpec( + description=self.config.description_, + operation=Operation( + message=Message( + title=f"{self.name}:Message", payload=resolve_payloads(self.get_payloads(), "Publisher") + ), + bindings=None, + ), + bindings=None, + ) + } + + +class StompPublisher(PublisherUsecase): + def __init__(self, config: StompPublisherUsecaseConfig, specification: StompPublisherSpecification) -> None: + self.config = config + super().__init__(config=config, specification=specification) # type: ignore[arg-type] + + async def _publish( + self, cmd: PublishCommand, *, _extra_middlewares: typing.Iterable[PublisherMiddleware[PublishCommand]] ) -> None: - self.destination = destination - super().__init__( - broker_middlewares=broker_middlewares, - middlewares=middlewares, - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, + publish_command = StompPublishCommand.from_cmd(cmd) + publish_command.destination = self.config.full_destination + return typing.cast( + "None", + await self._basic_publish( + publish_command, producer=self.config._outer_config.producer, _extra_middlewares=_extra_middlewares + ), ) - create = __init__ # type: ignore[assignment] - async def publish( - self, - message: SendableMessage, - *, - correlation_id: str | None = None, - headers: dict[str, str] | None = None, - _extra_middlewares: Sequence[PublisherMiddleware] = (), + self, message: SendableMessage, *, correlation_id: str | None = None, headers: dict[str, str] | None = None ) -> None: - assert self._producer, NOT_CONNECTED_YET # noqa: S101 - - call = self._producer.publish - for one_middleware in chain( - self._middlewares[::-1], # type: ignore[arg-type] - ( - _extra_middlewares # type: ignore[arg-type] - or (one_middleware(None).publish_scope for one_middleware in self._broker_middlewares[::-1]) + publish_command = StompPublishCommand( + message, + _publish_type=PublishType.PUBLISH, + destination=self.config.full_destination, + correlation_id=correlation_id, + headers=headers, + ) + return typing.cast( + "None", + await self._basic_publish( + publish_command, producer=self.config._outer_config.producer, _extra_middlewares=() ), - ): - call = partial(one_middleware, call) # type: ignore[operator, arg-type, misc] - - return await call(message, destination=self.destination, correlation_id=correlation_id, headers=headers or {}) + ) - async def request( # type: ignore[override] + async def request( self, message: SendableMessage, *, correlation_id: str | None = None, headers: dict[str, str] | None = None ) -> Any: # noqa: ANN401 - assert self._producer, NOT_CONNECTED_YET # noqa: S101 - return await self._producer.request(message, correlation_id=correlation_id, headers=headers) - - def __hash__(self) -> int: - return hash(f"publisher:{self.destination}") - - def get_name(self) -> str: - return f"{self.destination}:Publisher" - - def get_schema(self) -> dict[str, Channel]: - payloads = self.get_payloads() - - return { - self.name: Channel( - description=self.description, - publish=Operation( - message=Message( - title=f"{self.name}:Message", - payload=resolve_payloads(payloads, "Publisher"), - correlationId=CorrelationId(location="$message.header#/correlation_id"), - ), - ), - ) - } - - def add_prefix(self, prefix: str) -> None: - self.destination = f"{prefix}{self.destination}" + publish_command = StompPublishCommand( + message, + _publish_type=PublishType.REQUEST, + destination=self.config.full_destination, + correlation_id=correlation_id, + headers=headers, + ) + return await self._basic_request(publish_command, producer=self.config._outer_config.producer) diff --git a/packages/faststream-stomp/faststream_stomp/registrator.py b/packages/faststream-stomp/faststream_stomp/registrator.py index 1ba7327..0ea5335 100644 --- a/packages/faststream-stomp/faststream_stomp/registrator.py +++ b/packages/faststream-stomp/faststream_stomp/registrator.py @@ -1,18 +1,26 @@ from collections.abc import Iterable, Sequence -from typing import Any, cast +from typing import Any import stompman -from fast_depends.dependencies import Depends -from faststream.broker.core.abc import ABCBroker -from faststream.broker.types import CustomCallable, PublisherMiddleware, SubscriberMiddleware -from faststream.broker.utils import default_filter +from fast_depends.dependencies import Dependant +from faststream._internal.broker.registrator import Registrator +from faststream._internal.endpoint.subscriber.call_item import CallsCollection +from faststream._internal.types import CustomCallable, PublisherMiddleware, SubscriberMiddleware from typing_extensions import override -from faststream_stomp.publisher import StompPublisher -from faststream_stomp.subscriber import StompSubscriber +from faststream_stomp.models import ( + BrokerConfigWithStompClient, + StompPublishCommand, + StompPublisherSpecificationConfig, + StompPublisherUsecaseConfig, + StompSubscriberSpecificationConfig, + StompSubscriberUsecaseConfig, +) +from faststream_stomp.publisher import StompPublisher, StompPublisherSpecification +from faststream_stomp.subscriber import StompSubscriber, StompSubscriberSpecification -class StompRegistrator(ABCBroker[stompman.MessageFrame]): +class StompRegistrator(Registrator[stompman.MessageFrame, BrokerConfigWithStompClient]): @override def subscriber( # type: ignore[override] self, @@ -21,35 +29,36 @@ def subscriber( # type: ignore[override] ack_mode: stompman.AckMode = "client-individual", headers: dict[str, str] | None = None, # other args - dependencies: Iterable[Depends] = (), - no_ack: bool = False, + dependencies: Iterable[Dependant] = (), parser: CustomCallable | None = None, decoder: CustomCallable | None = None, middlewares: Sequence[SubscriberMiddleware[stompman.MessageFrame]] = (), - retry: bool = False, title: str | None = None, description: str | None = None, include_in_schema: bool = True, ) -> StompSubscriber: - subscriber = cast( - "StompSubscriber", - super().subscriber( - StompSubscriber( - destination=destination, - ack_mode=ack_mode, - headers=headers, - retry=retry, - no_ack=no_ack, - broker_middlewares=self._middlewares, - broker_dependencies=self._dependencies, - title_=title, - description_=description, - include_in_schema=self._solve_include_in_schema(include_in_schema), - ) + usecase_config = StompSubscriberUsecaseConfig( + _outer_config=self.config, # type: ignore[arg-type] + destination_without_prefix=destination, + ack_mode=ack_mode, + headers=headers, + ) + calls = CallsCollection[stompman.MessageFrame]() + specification = StompSubscriberSpecification( + _outer_config=self.config, # type: ignore[arg-type] + specification_config=StompSubscriberSpecificationConfig( + title_=title, + description_=description, + include_in_schema=include_in_schema, + destination_without_prefix=destination, + ack_mode=ack_mode, + headers=headers, ), + calls=calls, ) + subscriber = StompSubscriber(config=usecase_config, specification=specification, calls=calls) + super().subscriber(subscriber) return subscriber.add_call( - filter_=default_filter, parser_=parser or self._parser, decoder_=decoder or self._decoder, dependencies_=dependencies, @@ -61,23 +70,27 @@ def publisher( # type: ignore[override] self, destination: str, *, - middlewares: Sequence[PublisherMiddleware] = (), + middlewares: Sequence[PublisherMiddleware[StompPublishCommand]] = (), schema_: Any | None = None, title_: str | None = None, description_: str | None = None, include_in_schema: bool = True, ) -> StompPublisher: - return cast( - "StompPublisher", - super().publisher( - StompPublisher( - destination, - broker_middlewares=self._middlewares, - middlewares=middlewares, - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, - ) + usecase_config = StompPublisherUsecaseConfig( + _outer_config=self.config, # type: ignore[arg-type] + middlewares=middlewares, + destination_without_prefix=destination, + ) + specification = StompPublisherSpecification( + _outer_config=self.config, # type: ignore[arg-type] + specification_config=StompPublisherSpecificationConfig( + title_=title_, + description_=description_, + schema_=schema_, + include_in_schema=include_in_schema, + destination_without_prefix=destination, ), ) + publisher = StompPublisher(config=usecase_config, specification=specification) + super().publisher(publisher) + return publisher diff --git a/packages/faststream-stomp/faststream_stomp/router.py b/packages/faststream-stomp/faststream_stomp/router.py index 812bded..6d7c0f7 100644 --- a/packages/faststream-stomp/faststream_stomp/router.py +++ b/packages/faststream-stomp/faststream_stomp/router.py @@ -2,11 +2,13 @@ from typing import Any import stompman -from fast_depends.dependencies import Depends -from faststream.broker.router import ArgsContainer, BrokerRouter, SubscriberRoute -from faststream.broker.types import BrokerMiddleware, CustomCallable, PublisherMiddleware, SubscriberMiddleware -from faststream.types import SendableMessage +from fast_depends.dependencies import Dependant +from faststream._internal.basic_types import SendableMessage +from faststream._internal.broker.router import ArgsContainer, BrokerRouter, SubscriberRoute +from faststream._internal.configs import BrokerConfig +from faststream._internal.types import BrokerMiddleware, CustomCallable, PublisherMiddleware, SubscriberMiddleware +from faststream_stomp.models import StompPublishCommand from faststream_stomp.registrator import StompRegistrator @@ -20,7 +22,7 @@ def __init__( self, destination: str, *, - middlewares: Sequence[PublisherMiddleware] = (), + middlewares: Sequence[PublisherMiddleware[StompPublishCommand]] = (), schema_: Any | None = None, # noqa: ANN401 title_: str | None = None, description_: str | None = None, @@ -51,12 +53,10 @@ def __init__( headers: dict[str, str] | None = None, # other args publishers: Iterable[StompRoutePublisher] = (), - dependencies: Iterable[Depends] = (), - no_ack: bool = False, + dependencies: Iterable[Dependant] = (), parser: CustomCallable | None = None, decoder: CustomCallable | None = None, middlewares: Sequence[SubscriberMiddleware[stompman.MessageFrame]] = (), - retry: bool = False, title: str | None = None, description: str | None = None, include_in_schema: bool = True, @@ -68,18 +68,16 @@ def __init__( headers=headers, publishers=publishers, dependencies=dependencies, - no_ack=no_ack, parser=parser, decoder=decoder, middlewares=middlewares, - retry=retry, title=title, description=description, include_in_schema=include_in_schema, ) -class StompRouter(StompRegistrator, BrokerRouter[stompman.MessageFrame]): +class StompRouter(StompRegistrator, BrokerRouter[stompman.MessageFrame, BrokerConfig]): """Includable to StompBroker router.""" def __init__( @@ -87,18 +85,22 @@ def __init__( prefix: str = "", handlers: Iterable[StompRoute] = (), *, - dependencies: Iterable[Depends] = (), + dependencies: Iterable[Dependant] = (), middlewares: Sequence[BrokerMiddleware[stompman.MessageFrame]] = (), parser: CustomCallable | None = None, decoder: CustomCallable | None = None, include_in_schema: bool | None = None, + routers: Sequence[StompRegistrator] = (), ) -> None: super().__init__( + config=BrokerConfig( + broker_middlewares=middlewares, + broker_dependencies=dependencies, + broker_parser=parser, + broker_decoder=decoder, + include_in_schema=include_in_schema, + prefix=prefix, + ), handlers=handlers, - prefix=prefix, - dependencies=dependencies, - middlewares=middlewares, - parser=parser, - decoder=decoder, - include_in_schema=include_in_schema, + routers=routers, # type: ignore[arg-type] ) diff --git a/packages/faststream-stomp/faststream_stomp/subscriber.py b/packages/faststream-stomp/faststream_stomp/subscriber.py index fb71cf9..52e73b5 100644 --- a/packages/faststream-stomp/faststream_stomp/subscriber.py +++ b/packages/faststream-stomp/faststream_stomp/subscriber.py @@ -1,142 +1,94 @@ -from collections.abc import Callable, Iterable, Sequence -from typing import Any, TypedDict, cast +import asyncio +from collections.abc import AsyncIterator, Sequence +from typing import Any, NoReturn import stompman -from fast_depends.dependencies import Depends -from faststream.asyncapi.schema import Channel, CorrelationId, Message, Operation -from faststream.asyncapi.utils import resolve_payloads -from faststream.broker.message import StreamMessage, decode_message -from faststream.broker.publisher.fake import FakePublisher -from faststream.broker.publisher.proto import ProducerProto -from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.types import AsyncCallable, BrokerMiddleware, CustomCallable -from faststream.types import AnyDict, Decorator, LoggerProto -from faststream.utils.functions import to_async +from faststream import PublishCommand, StreamMessage +from faststream._internal.configs import BrokerConfig +from faststream._internal.endpoint.publisher.fake import FakePublisher +from faststream._internal.endpoint.subscriber import SubscriberSpecification, SubscriberUsecase +from faststream._internal.endpoint.subscriber.call_item import CallsCollection +from faststream._internal.producer import ProducerProto +from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, SubscriberSpec + +from faststream_stomp.models import ( + StompPublishCommand, + StompSubscriberSpecificationConfig, + StompSubscriberUsecaseConfig, +) + + +class StompSubscriberSpecification(SubscriberSpecification[BrokerConfig, StompSubscriberSpecificationConfig]): + @property + def name(self) -> str: + return f"{self._outer_config.prefix}{self.config.destination_without_prefix}:{self.call_name}" + + def get_schema(self) -> dict[str, SubscriberSpec]: + return { + self.name: SubscriberSpec( + description=self.description, + operation=Operation( + message=Message(title=f"{self.name}:Message", payload=resolve_payloads(self.get_payloads())), + bindings=None, + ), + bindings=None, + ) + } -from faststream_stomp.message import StompStreamMessage +class StompFakePublisher(FakePublisher): + def __init__(self, *, producer: ProducerProto[Any], reply_to: str) -> None: + super().__init__(producer=producer) + self.reply_to = reply_to -class StompLogContext(TypedDict): - destination: str - message_id: str + def patch_command(self, cmd: PublishCommand | StompPublishCommand) -> StompPublishCommand: + cmd = super().patch_command(cmd) + real_cmd = StompPublishCommand.from_cmd(cmd) + real_cmd.destination = self.reply_to + return real_cmd class StompSubscriber(SubscriberUsecase[stompman.MessageFrame]): def __init__( self, *, - destination: str, - ack_mode: stompman.AckMode, - headers: dict[str, str] | None, - retry: bool | int, - no_ack: bool, - broker_dependencies: Iterable[Depends], - broker_middlewares: Sequence[BrokerMiddleware[stompman.MessageFrame]], - default_parser: AsyncCallable = StompStreamMessage.from_frame, - default_decoder: AsyncCallable = to_async(decode_message), # noqa: B008 - # AsyncAPI information - title_: str | None, - description_: str | None, - include_in_schema: bool, + config: StompSubscriberUsecaseConfig, + specification: StompSubscriberSpecification, + calls: CallsCollection[stompman.MessageFrame], ) -> None: - self.destination = destination - self.ack_mode = ack_mode - self.headers = headers + self.config = config self._subscription: stompman.ManualAckSubscription | None = None - - super().__init__( - no_ack=no_ack or self.ack_mode == "auto", - no_reply=True, - retry=retry, - broker_dependencies=broker_dependencies, - broker_middlewares=broker_middlewares, - default_parser=default_parser, - default_decoder=default_decoder, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, - ) - - def setup( # type: ignore[override] - self, - client: stompman.Client, - *, - logger: LoggerProto | None, - producer: ProducerProto | None, - graceful_timeout: float | None, - extra_context: AnyDict, - broker_parser: CustomCallable | None, - broker_decoder: CustomCallable | None, - apply_types: bool, - is_validate: bool, - _get_dependant: Callable[..., Any] | None, - _call_decorators: Iterable[Decorator], - ) -> None: - self.client = client - return super().setup( - logger=logger, - producer=producer, - graceful_timeout=graceful_timeout, - extra_context=extra_context, - broker_parser=broker_parser, - broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, - ) + super().__init__(config=config, specification=specification, calls=calls) # type: ignore[arg-type] async def start(self) -> None: await super().start() - self._subscription = await self.client.subscribe_with_manual_ack( - destination=self.destination, + self._subscription = await self.config._outer_config.client.subscribe_with_manual_ack( + destination=self.config.full_destination, handler=self.consume, - ack=self.ack_mode, - headers=self.headers, + ack=self.config.ack_mode, + headers=self.config.headers, ) + self._post_start() async def stop(self) -> None: if self._subscription: await self._subscription.unsubscribe() await super().stop() - async def get_one(self, *, timeout: float = 5) -> None: ... + async def get_one(self, *, timeout: float = 5) -> NoReturn: + raise NotImplementedError - def _make_response_publisher(self, message: StreamMessage[stompman.MessageFrame]) -> Sequence[FakePublisher]: - return ( # pragma: no cover - (FakePublisher(self._producer.publish, publish_kwargs={"destination": message.reply_to}),) - if self._producer - else () - ) - - def __hash__(self) -> int: - return hash(self.destination) - - def add_prefix(self, prefix: str) -> None: - self.destination = f"{prefix}{self.destination}" + async def __aiter__(self) -> AsyncIterator[StreamMessage[stompman.MessageFrame]]: # type: ignore[override, misc] + raise NotImplementedError + yield # pragma: no cover + await asyncio.sleep(0) # pragma: no cover - def get_name(self) -> str: - return f"{self.destination}:{self.call_name}" - - def get_schema(self) -> dict[str, Channel]: - payloads = self.get_payloads() - - return { - self.name: Channel( - description=self.description, - subscribe=Operation( - message=Message( - title=f"{self.name}:Message", - payload=resolve_payloads(payloads), - correlationId=CorrelationId(location="$message.header#/correlation_id"), - ), - ), - ) - } + def _make_response_publisher(self, message: StreamMessage[stompman.MessageFrame]) -> Sequence[FakePublisher]: + return (StompFakePublisher(producer=self.config._outer_config.producer, reply_to=message.reply_to),) def get_log_context(self, message: StreamMessage[stompman.MessageFrame] | None) -> dict[str, str]: - log_context: StompLogContext = { - "destination": message.raw_message.headers["destination"] if message else self.destination, + return { + "destination": message.raw_message.headers["destination"] if message else self.config.full_destination, "message_id": message.message_id if message else "", } - return cast("dict[str, str]", log_context) diff --git a/packages/faststream-stomp/faststream_stomp/testing.py b/packages/faststream-stomp/faststream_stomp/testing.py index 3cb5da3..efbe540 100644 --- a/packages/faststream-stomp/faststream_stomp/testing.py +++ b/packages/faststream-stomp/faststream_stomp/testing.py @@ -1,14 +1,16 @@ import uuid +from collections.abc import Generator, Iterator +from contextlib import contextmanager from typing import TYPE_CHECKING, Any from unittest import mock from unittest.mock import AsyncMock import stompman -from faststream.broker.message import encode_message -from faststream.testing.broker import TestBroker -from faststream.types import SendableMessage +from faststream._internal.testing.broker import TestBroker, change_producer +from faststream.message import encode_message from faststream_stomp.broker import StompBroker +from faststream_stomp.models import StompPublishCommand from faststream_stomp.publisher import StompProducer, StompPublisher from faststream_stomp.subscriber import StompSubscriber @@ -22,27 +24,31 @@ def create_publisher_fake_subscriber( broker: StompBroker, publisher: StompPublisher ) -> tuple[StompSubscriber, bool]: subscriber: StompSubscriber | None = None - for handler in broker._subscribers.values(): # noqa: SLF001 - if handler.destination == publisher.destination: + for handler in broker._subscribers: + if handler.config.full_destination == publisher.config.full_destination: subscriber = handler break if subscriber is None: is_real = False - subscriber = broker.subscriber(publisher.destination) + subscriber = broker.subscriber(publisher.config.full_destination) else: is_real = True return subscriber, is_real + @contextmanager + def _patch_producer(self, broker: StompBroker) -> Iterator[None]: # noqa: PLR6301 + with change_producer(broker.config.broker_config, FakeStompProducer(broker)): + yield + + @contextmanager + def _patch_broker(self, broker: StompBroker) -> Generator[None, None, None]: + with mock.patch.object(broker.config, "client", new_callable=AsyncMock), super()._patch_broker(broker): + yield + @staticmethod - async def _fake_connect( - broker: StompBroker, - *args: Any, # noqa: ANN401, ARG004 - **kwargs: Any, # noqa: ANN401, ARG004 - ) -> None: - broker._connection = AsyncMock() # noqa: SLF001 - broker._producer = FakeStompProducer(broker) # noqa: SLF001 + async def _fake_connect(broker: StompBroker, *args: Any, **kwargs: Any) -> None: ... # noqa: ANN401 class FakeAckableMessageFrame(stompman.AckableMessageFrame): @@ -55,26 +61,18 @@ class FakeStompProducer(StompProducer): def __init__(self, broker: StompBroker) -> None: self.broker = broker - async def publish( # type: ignore[override] - self, - message: SendableMessage, - *, - destination: str, - correlation_id: str | None, - headers: dict[str, str] | None, - ) -> None: - body, content_type = encode_message(message) - all_headers: MessageHeaders = (headers.copy() if headers else {}) | { # type: ignore[assignment] - "destination": destination, + async def publish(self, cmd: StompPublishCommand) -> None: + body, content_type = encode_message(cmd.body, serializer=None) + all_headers: MessageHeaders = (cmd.headers.copy() if cmd.headers else {}) | { # type: ignore[assignment] + "destination": cmd.destination, "message-id": str(uuid.uuid4()), "subscription": str(uuid.uuid4()), } - if correlation_id: - all_headers["correlation-id"] = correlation_id # type: ignore[typeddict-unknown-key] + if cmd.correlation_id: + all_headers["correlation-id"] = cmd.correlation_id # type: ignore[typeddict-unknown-key] if content_type: all_headers["content-type"] = content_type frame = FakeAckableMessageFrame(headers=all_headers, body=body, _subscription=mock.AsyncMock()) - - for handler in self.broker._subscribers.values(): # noqa: SLF001 - if handler.destination == destination: + for handler in self.broker._subscribers: + if handler.config.full_destination == cmd.destination: await handler.process_message(frame) diff --git a/packages/faststream-stomp/pyproject.toml b/packages/faststream-stomp/pyproject.toml index b13182e..ba6b402 100644 --- a/packages/faststream-stomp/pyproject.toml +++ b/packages/faststream-stomp/pyproject.toml @@ -24,8 +24,10 @@ requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [dependency-groups] -dev = ["faststream[otel,prometheus]~=0.5", "asgi-lifespan"] +dev = ["faststream[otel,prometheus]~=0.6", "asgi-lifespan"] +[tool.uv.sources] +faststream = {git="https://github.com/ag2ai/faststream", branch="0.6.0"} [tool.hatch.version] source = "vcs" raw-options.root = "../.." diff --git a/packages/faststream-stomp/test_faststream_stomp/test_integration.py b/packages/faststream-stomp/test_faststream_stomp/test_integration.py index 8e6da98..ca35ef9 100644 --- a/packages/faststream-stomp/test_faststream_stomp/test_integration.py +++ b/packages/faststream-stomp/test_faststream_stomp/test_integration.py @@ -11,10 +11,10 @@ from asgi_lifespan import LifespanManager from faststream import BaseMiddleware, Context, FastStream from faststream.asgi import AsgiFastStream -from faststream.broker.message import gen_cor_id -from faststream.broker.middlewares.logging import CriticalLogMiddleware from faststream.exceptions import AckMessage, NackMessage, RejectMessage -from faststream_stomp.message import StompStreamMessage +from faststream.message import gen_cor_id +from faststream_stomp.models import StompStreamMessage +from faststream_stomp.router import StompRoutePublisher if TYPE_CHECKING: from faststream_stomp.broker import StompBroker @@ -80,11 +80,15 @@ async def _() -> None: async def test_router(faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: expected_body, prefix, destination = faker.pystr(), faker.pystr(), faker.pystr() - def route(body: str, message: stompman.MessageFrame = Context("message.raw_message")) -> None: # noqa: B008 + def route(body: str, message: stompman.MessageFrame = Context("message.raw_message")) -> bytes: # noqa: B008 assert body == expected_body event.set() + return message.body - router = faststream_stomp.StompRouter(prefix=prefix, handlers=(faststream_stomp.StompRoute(route, destination),)) + router = faststream_stomp.StompRouter( + prefix=prefix, + handlers=[faststream_stomp.StompRoute(route, destination, publishers=[StompRoutePublisher(faker.pystr())])], + ) publisher = router.publisher(destination) broker.include_router(router) @@ -168,8 +172,7 @@ async def test_ok( ) -> None: monkeypatch.delenv("PYTEST_CURRENT_TEST") broker: StompBroker = request.getfixturevalue("broker") - assert broker.logger - broker.logger = mock.Mock(log=(log_mock := mock.Mock()), handlers=[]) + broker.config.logger.logger = mock.Mock(log=(log_mock := mock.Mock()), handlers=[]) @broker.subscriber(destination := faker.pystr()) def some_handler() -> None: ... @@ -191,9 +194,7 @@ async def test_raises( ) -> None: monkeypatch.delenv("PYTEST_CURRENT_TEST") broker: StompBroker = request.getfixturevalue("broker") - assert isinstance(broker._middlewares[0], CriticalLogMiddleware) - assert broker._middlewares[0].logger - broker._middlewares[0].logger = mock.Mock(log=(log_mock := mock.Mock())) + broker.config.broker_config.logger = mock.Mock(log=(log_mock := mock.Mock())) event = asyncio.Event() message_id: str | None = None @@ -214,10 +215,10 @@ def some_handler(message_frame: Annotated[stompman.MessageFrame, Context("messag assert message_id extra = {"destination": destination, "message_id": message_id} - assert log_mock.mock_calls == [ - mock.call(logging.INFO, "Received", extra=extra), - mock.call(logging.ERROR, "MyError: ", extra=extra, exc_info=MyError()), - mock.call(logging.INFO, "Processed", extra=extra), + assert log_mock.mock_calls[-3:] == [ + mock.call("Received", extra=extra), + mock.call(message="MyError: ", extra=extra, exc_info=MyError()), + mock.call(message="Processed", extra=extra), ] diff --git a/packages/faststream-stomp/test_faststream_stomp/test_main.py b/packages/faststream-stomp/test_faststream_stomp/test_main.py index 4417225..a4f7e17 100644 --- a/packages/faststream-stomp/test_faststream_stomp/test_main.py +++ b/packages/faststream-stomp/test_faststream_stomp/test_main.py @@ -1,12 +1,9 @@ -from unittest import mock - import faker import faststream_stomp import pytest import stompman from faststream import FastStream -from faststream.asyncapi import get_app_schema -from faststream.broker.message import gen_cor_id +from faststream.message import gen_cor_id from faststream_stomp.opentelemetry import StompTelemetryMiddleware from faststream_stomp.prometheus import StompPrometheusMiddleware from opentelemetry.sdk.metrics import MeterProvider @@ -59,33 +56,42 @@ def second_handle(body: str) -> None: third_publisher.mock.assert_called_once_with(expected_body) -async def test_broker_request_not_implemented(faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: - async with faststream_stomp.TestStompBroker(broker): - with pytest.raises(NotImplementedError): - await broker.request(faker.pystr()) +class TestNotImplemented: + async def test_broker_request(self, faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: + async with faststream_stomp.TestStompBroker(broker): + with pytest.raises(NotImplementedError): + await broker.request(faker.pystr(), faker.pystr()) + async def test_broker_publish_batch(self, faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: + async with faststream_stomp.TestStompBroker(broker): + with pytest.raises(NotImplementedError): + await broker.publish_batch(faker.pystr(), destination=faker.pystr()) -async def test_publisher_request_not_implemented(faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: - async with faststream_stomp.TestStompBroker(broker): - with pytest.raises(NotImplementedError): - await broker.publisher(faker.pystr()).request(faker.pystr()) + async def test_publisher_request(self, faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: + async with faststream_stomp.TestStompBroker(broker): + with pytest.raises(NotImplementedError): + await broker.publisher(faker.pystr()).request(faker.pystr()) + async def test_subscriber_get_one(self, faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: + async with faststream_stomp.TestStompBroker(broker): + with pytest.raises(NotImplementedError): + await broker.subscriber(faker.pystr()).get_one() -def test_get_fmt(broker: faststream_stomp.StompBroker) -> None: - broker.get_fmt() + async def test_subscriber_aiter(self, faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: + async with faststream_stomp.TestStompBroker(broker): + with pytest.raises(NotImplementedError): + async for _ in broker.subscriber(faker.pystr()): + ... # pragma: no cover def test_asyncapi_schema(faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: - broker.include_router( - faststream_stomp.StompRouter( - handlers=( - faststream_stomp.StompRoute( - mock.Mock(), faker.pystr(), publishers=(faststream_stomp.StompRoutePublisher(faker.pystr()),) - ), - ) - ) - ) - get_app_schema(FastStream(broker)) + @broker.publisher(faker.pystr()) + def _publisher() -> None: ... + + @broker.subscriber(faker.pystr()) + def _subscriber() -> None: ... + + FastStream(broker).schema.to_specification() async def test_opentelemetry_publish(faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None: diff --git a/packages/stompman/stompman/client.py b/packages/stompman/stompman/client.py index 3915e03..b34b098 100644 --- a/packages/stompman/stompman/client.py +++ b/packages/stompman/stompman/client.py @@ -49,8 +49,8 @@ class Client: _active_subscriptions: ActiveSubscriptions = field(default_factory=ActiveSubscriptions, init=False) _active_transactions: set[Transaction] = field(default_factory=set, init=False) _exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False) - _listen_task: asyncio.Task[None] = field(init=False) - _task_group: asyncio.TaskGroup = field(init=False) + _listen_task: asyncio.Task[None] = field(init=False, repr=False) + _task_group: asyncio.TaskGroup = field(init=False, repr=False) def __post_init__(self) -> None: self._connection_manager = ConnectionManager( @@ -98,7 +98,7 @@ async def _listen_to_frames(self) -> None: case MessageFrame(): if subscription := self._active_subscriptions.get_by_id(frame.headers["subscription"]): task_group.create_task( - subscription._run_handler(frame=frame) # noqa: SLF001 + subscription._run_handler(frame=frame) if isinstance(subscription, AutoAckSubscription) else subscription.handler( AckableMessageFrame( @@ -159,7 +159,7 @@ async def subscribe( _connection_manager=self._connection_manager, _active_subscriptions=self._active_subscriptions, ) - await subscription._subscribe() # noqa: SLF001 + await subscription._subscribe() return subscription async def subscribe_with_manual_ack( @@ -178,10 +178,10 @@ async def subscribe_with_manual_ack( _connection_manager=self._connection_manager, _active_subscriptions=self._active_subscriptions, ) - await subscription._subscribe() # noqa: SLF001 + await subscription._subscribe() return subscription def is_alive(self) -> bool: return ( - self._connection_manager._active_connection_state or False # noqa: SLF001 - ) and self._connection_manager._active_connection_state.is_alive(self.check_server_alive_interval_factor) # noqa: SLF001 + self._connection_manager._active_connection_state or False + ) and self._connection_manager._active_connection_state.is_alive(self.check_server_alive_interval_factor) diff --git a/packages/stompman/stompman/connection_manager.py b/packages/stompman/stompman/connection_manager.py index 2bad19f..360f470 100644 --- a/packages/stompman/stompman/connection_manager.py +++ b/packages/stompman/stompman/connection_manager.py @@ -102,7 +102,7 @@ async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: i if not self._active_connection_state: continue if not self._active_connection_state.is_alive(self.check_server_alive_interval_factor): - self._active_connection_state = None + self._clear_active_connection_state() async def _create_connection_to_one_server( self, server: ConnectionParameters diff --git a/packages/stompman/stompman/subscription.py b/packages/stompman/stompman/subscription.py index 3d4af14..07869fb 100644 --- a/packages/stompman/stompman/subscription.py +++ b/packages/stompman/stompman/subscription.py @@ -112,10 +112,10 @@ class AckableMessageFrame(MessageFrame): _subscription: ManualAckSubscription async def ack(self) -> None: - await self._subscription._ack(self) # noqa: SLF001 + await self._subscription._ack(self) async def nack(self) -> None: - await self._subscription._nack(self) # noqa: SLF001 + await self._subscription._nack(self) def _make_subscription_id() -> str: diff --git a/pyproject.toml b/pyproject.toml index ef773f8..1f247df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dev = [ "polyfactory==2.19.0", "pytest==8.3.4", "pytest-cov==6.0.0", + "pytest-timeout==2.4.0", "ruff==0.9.3", "uvloop==0.21.0", ] @@ -49,8 +50,10 @@ ignore = [ "DOC201", "DOC501", "ISC001", + "PLC2701", "PLC2801", "PLR0913", + "SLF001" ] extend-per-file-ignores = { "*/test_*/*" = ["S101", "SLF001", "ARG", "PLR6301"] }