11import asyncio
22import logging
33import types
4- from collections .abc import Callable , Iterable , Mapping , Sequence
5- from typing import Any , Unpack
4+ import typing
5+ from collections .abc import Iterable , Sequence
6+ from typing import Any
67
78import anyio
89import stompman
9- from fast_depends .dependencies import Depends
10- from faststream .asyncapi .schema import Tag , TagDict
11- from faststream .broker .core .usecase import BrokerUsecase
12- from faststream .broker .types import BrokerMiddleware , CustomCallable
13- from faststream .log .logging import get_broker_logger
10+ from fast_depends .dependencies import Dependant
11+ from faststream import ContextRepo , PublishType
12+ from faststream ._internal .basic_types import LoggerProto , SendableMessage
13+ from faststream ._internal .broker import BrokerUsecase
14+ from faststream ._internal .broker .registrator import Registrator
15+ from faststream ._internal .configs import BrokerConfig
16+ from faststream ._internal .constants import EMPTY
17+ from faststream ._internal .di import FastDependsConfig
18+ from faststream ._internal .logger import DefaultLoggerStorage , make_logger_state
19+ from faststream ._internal .logger .logging import get_broker_logger
20+ from faststream ._internal .types import BrokerMiddleware , CustomCallable
1421from faststream .security import BaseSecurity
15- from faststream .types import EMPTY , AnyDict , Decorator , LoggerProto , SendableMessage
22+ from faststream .specification .schema import BrokerSpec
23+ from faststream .specification .schema .extra import Tag , TagDict
1624
25+ from faststream_stomp .models import BrokerConfigWithStompClient , StompPublishCommand
1726from faststream_stomp .publisher import StompProducer , StompPublisher
1827from faststream_stomp .registrator import StompRegistrator
19- from faststream_stomp .subscriber import StompLogContext , StompSubscriber
28+ from faststream_stomp .subscriber import StompSubscriber
2029
2130
2231class StompSecurity (BaseSecurity ):
2332 def __init__ (self ) -> None :
2433 self .ssl_context = None
2534 self .use_ssl = False
2635
27- def get_requirement (self ) -> list [AnyDict ]: # noqa: PLR6301
36+ def get_requirement (self ) -> list [dict [ str , Any ] ]: # noqa: PLR6301
2837 return [{"user-password" : []}]
2938
3039 def get_schema (self ) -> dict [str , dict [str , str ]]: # noqa: PLR6301
@@ -43,83 +52,111 @@ def _handle_listen_task_done(listen_task: asyncio.Task[None]) -> None:
4352 raise SystemExit (1 )
4453
4554
46- class StompBroker (StompRegistrator , BrokerUsecase [stompman .MessageFrame , stompman .Client ]):
47- _subscribers : Mapping [int , StompSubscriber ]
48- _publishers : Mapping [int , StompPublisher ]
55+ class StompParamsStorage (DefaultLoggerStorage ):
4956 __max_msg_id_ln = 10
5057 _max_channel_name = 4
5158
59+ def get_logger (self , * , context : ContextRepo ) -> LoggerProto :
60+ if logger := self ._get_logger_ref ():
61+ return logger
62+ logger = get_broker_logger (
63+ name = "stomp" ,
64+ default_context = {"destination" : "" , "message_id" : "" },
65+ message_id_ln = self .__max_msg_id_ln ,
66+ fmt = (
67+ "%(asctime)s %(levelname)-8s - "
68+ f"%(destination)-{ self ._max_channel_name } s | "
69+ f"%(message_id)-{ self .__max_msg_id_ln } s "
70+ "- %(message)s"
71+ ),
72+ context = context ,
73+ log_level = self .logger_log_level ,
74+ )
75+ self ._logger_ref .add (logger )
76+ return logger
77+
78+
79+ class StompBroker (
80+ StompRegistrator ,
81+ BrokerUsecase [
82+ stompman .MessageFrame ,
83+ stompman .Client ,
84+ BrokerConfig , # Using BrokerConfig to avoid typing issues when passing broker to FastStream app
85+ ],
86+ ):
87+ _subscribers : list [StompSubscriber ] # type: ignore[assignment]
88+ _publishers : list [StompPublisher ] # type: ignore[assignment]
89+
5290 def __init__ (
5391 self ,
5492 client : stompman .Client ,
5593 * ,
5694 decoder : CustomCallable | None = None ,
5795 parser : CustomCallable | None = None ,
58- dependencies : Iterable [Depends ] = (),
59- middlewares : Sequence [BrokerMiddleware [stompman .MessageFrame ]] = (),
96+ dependencies : Iterable [Dependant ] = (),
97+ middlewares : Sequence [BrokerMiddleware [stompman .MessageFrame , StompPublishCommand ]] = (),
6098 graceful_timeout : float | None = 15.0 ,
99+ routers : Sequence [Registrator [stompman .MessageFrame ]] = (),
61100 # Logging args
62101 logger : LoggerProto | None = EMPTY ,
63102 log_level : int = logging .INFO ,
64103 # FastDepends args
65104 apply_types : bool = True ,
66- validate : bool = True ,
67- _get_dependant : Callable [..., Any ] | None = None ,
68- _call_decorators : Iterable [Decorator ] = (),
69- # AsyncAPI kwargs,
105+ # AsyncAPI args
70106 description : str | None = None ,
71- tags : Iterable [Tag | TagDict ] | None = None ,
107+ tags : Iterable [Tag | TagDict ] = () ,
72108 ) -> None :
73- super ().__init__ (
74- client = client , # **connection_kwargs
75- decoder = decoder ,
76- parser = parser ,
77- dependencies = dependencies ,
78- middlewares = middlewares ,
109+ broker_config = BrokerConfigWithStompClient (
110+ broker_middlewares = middlewares , # type: ignore[arg-type]
111+ broker_parser = parser ,
112+ broker_decoder = decoder ,
113+ logger = make_logger_state (
114+ logger = logger ,
115+ log_level = log_level ,
116+ default_storage_cls = StompParamsStorage , # type: ignore[type-abstract]
117+ ),
118+ fd_config = FastDependsConfig (use_fastdepends = apply_types ),
119+ broker_dependencies = dependencies ,
79120 graceful_timeout = graceful_timeout ,
80- logger = logger ,
81- log_level = log_level ,
82- apply_types = apply_types ,
83- validate = validate ,
84- _get_dependant = _get_dependant ,
85- _call_decorators = _call_decorators ,
121+ extra_context = { "broker" : self } ,
122+ producer = StompProducer ( client ) ,
123+ client = client ,
124+ )
125+ specification = BrokerSpec (
126+ url = [ f" { one_server . host } : { one_server . port } " for one_server in broker_config . client . servers ] ,
86127 protocol = "STOMP" ,
87128 protocol_version = "1.2" ,
88129 description = description ,
89130 tags = tags ,
90- asyncapi_url = [f"{ one_server .host } :{ one_server .port } " for one_server in client .servers ],
91131 security = StompSecurity (),
92- default_logger = get_broker_logger (
93- name = "stomp" , default_context = {"channel" : "" }, message_id_ln = self .__max_msg_id_ln
94- ),
95132 )
96- self ._attempted_to_connect = False
97-
98- async def start (self ) -> None :
99- await super ().start ()
100133
101- for handler in self ._subscribers .values ():
102- self ._log (f"`{ handler .call_name } ` waiting for messages" , extra = handler .get_log_context (None ))
103- await handler .start ()
134+ super ().__init__ (config = broker_config , specification = specification , routers = routers )
135+ self ._attempted_to_connect = False
104136
105- async def _connect (self , client : stompman . Client ) -> stompman .Client : # type: ignore[override]
137+ async def _connect (self ) -> stompman .Client :
106138 if self ._attempted_to_connect :
107- return client
139+ return self . config . broker_config . client
108140 self ._attempted_to_connect = True
109- self ._producer = StompProducer (client )
110- await client .__aenter__ ()
111- client ._listen_task .add_done_callback (_handle_listen_task_done ) # noqa: SLF001
112- return client
141+ await self .config .broker_config .client .__aenter__ ()
142+ self .config .broker_config .client ._listen_task .add_done_callback (_handle_listen_task_done )
143+ return self .config .broker_config .client
144+
145+ async def start (self ) -> None :
146+ await self .connect ()
147+ await super ().start ()
113148
114- async def _close (
149+ async def stop (
115150 self ,
116151 exc_type : type [BaseException ] | None = None ,
117152 exc_val : BaseException | None = None ,
118153 exc_tb : types .TracebackType | None = None ,
119154 ) -> None :
155+ for sub in self .subscribers :
156+ await sub .stop ()
120157 if self ._connection :
121158 await self ._connection .__aexit__ (exc_type , exc_val , exc_tb )
122- return await super (). _close ( exc_type , exc_val , exc_tb )
159+ self . running = False
123160
124161 async def ping (self , timeout : float | None = None ) -> bool :
125162 sleep_time = (timeout or 10 ) / 10
@@ -138,42 +175,52 @@ async def ping(self, timeout: float | None = None) -> bool:
138175
139176 return False # pragma: no cover
140177
141- def get_fmt (self ) -> str :
142- # `StompLogContext`
143- return (
144- "%(asctime)s %(levelname)-8s - "
145- f"%(destination)-{ self ._max_channel_name } s | "
146- f"%(message_id)-{ self .__max_msg_id_ln } s "
147- "- %(message)s"
148- )
149-
150- def _setup_log_context (self , ** log_context : Unpack [StompLogContext ]) -> None : ... # type: ignore[override]
151-
152- @property
153- def _subscriber_setup_extra (self ) -> "AnyDict" :
154- return {** super ()._subscriber_setup_extra , "client" : self ._connection }
155-
156- async def publish ( # type: ignore[override]
178+ async def publish (
157179 self ,
158180 message : SendableMessage ,
159181 destination : str ,
160182 * ,
161183 correlation_id : str | None = None ,
162184 headers : dict [str , str ] | None = None ,
163185 ) -> None :
164- await super (). publish (
186+ publish_command = StompPublishCommand (
165187 message ,
166- producer = self ._producer ,
167- correlation_id = correlation_id ,
188+ _publish_type = PublishType .PUBLISH ,
168189 destination = destination ,
190+ correlation_id = correlation_id ,
169191 headers = headers ,
170192 )
193+ return typing .cast ("None" , await self ._basic_publish (publish_command , producer = self .config .producer ))
171194
172195 async def request ( # type: ignore[override]
173196 self ,
174- msg : Any , # noqa: ANN401
197+ message : SendableMessage ,
198+ destination : str ,
175199 * ,
176200 correlation_id : str | None = None ,
177201 headers : dict [str , str ] | None = None ,
178202 ) -> Any : # noqa: ANN401
179- return await super ().request (msg , producer = self ._producer , correlation_id = correlation_id , headers = headers )
203+ publish_command = StompPublishCommand (
204+ message ,
205+ _publish_type = PublishType .REQUEST ,
206+ destination = destination ,
207+ correlation_id = correlation_id ,
208+ headers = headers ,
209+ )
210+ return await self ._basic_request (publish_command , producer = self .config .producer )
211+
212+ async def publish_batch ( # type: ignore[override]
213+ self ,
214+ * _messages : SendableMessage ,
215+ destination : str ,
216+ correlation_id : str | None = None ,
217+ headers : dict [str , str ] | None = None ,
218+ ) -> None :
219+ publish_command = StompPublishCommand (
220+ "" ,
221+ _publish_type = PublishType .PUBLISH ,
222+ destination = destination ,
223+ correlation_id = correlation_id ,
224+ headers = headers ,
225+ )
226+ return typing .cast ("None" , await self ._basic_publish_batch (publish_command , producer = self .config .producer ))
0 commit comments