Support multiple quorums on a single LighthouseServer using gRPC metadata-based room assignment #189
Changes from 2 commits
src/lib.rs
```diff
@@ -8,8 +8,11 @@ pub mod lighthouse;
 pub mod manager;
 mod net;
 mod retry;
+mod router;
 mod timeout;
 
+pub use crate::router::Router;
+
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
@@ -21,6 +24,7 @@ use std::thread::available_parallelism;
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
+use tokio_stream::wrappers::TcpListenerStream;
 use tonic::transport::Channel;
 use tonic::Status;
@@ -33,7 +37,9 @@ pub mod torchftpb {
 }
 
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
+use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
+use crate::torchftpb::LighthouseHeartbeatRequest;
 use crate::torchftpb::{
     CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
     ManagerQuorumRequest, ShouldCommitRequest,
```
```diff
@@ -339,9 +345,12 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
 }
 
 async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
-    let lighthouse = lighthouse::Lighthouse::new(opt).await?;
+    let router = Router::new(opt.clone());
 
-    lighthouse.run().await?;
+    tonic::transport::Server::builder()
+        .add_service(LighthouseServiceServer::new(router))
+        .serve(opt.bind.parse::<std::net::SocketAddr>()?)
+        .await?;
 
     Ok(())
 }
```
```diff
@@ -479,13 +488,19 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 struct LighthouseClient {
     client: LighthouseServiceClient<Channel>,
     runtime: Runtime,
+    room_id: Option<String>,
 }
 
 #[pymethods]
 impl LighthouseClient {
-    #[pyo3(signature = (addr, connect_timeout))]
+    #[pyo3(signature = (addr, connect_timeout, room_id = None))]
     #[new]
-    fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
+    fn new(
+        py: Python<'_>,
+        addr: String,
+        connect_timeout: Duration,
+        room_id: Option<String>,
+    ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = tokio::runtime::Builder::new_multi_thread()
                 .worker_threads(num_threads())
@@ -498,6 +513,7 @@ impl LighthouseClient {
             Ok(Self {
                 client: client,
                 runtime: runtime,
+                room_id: room_id,
             })
         })
     }
```
```diff
@@ -553,6 +569,8 @@ impl LighthouseClient {
             }),
         });
 
+        let mut request = self.add_room_header(request);
+
         // This timeout is processed on the server side so we also enable
         // keep alives to detect server health.
         request.set_timeout(timeout);
```
```diff
@@ -581,13 +599,29 @@ impl LighthouseClient {
     ) -> Result<(), StatusError> {
         py.allow_threads(move || {
             let mut req = tonic::Request::new(LighthouseHeartbeatRequest { replica_id });
+            let mut req = self.add_room_header(req);
             req.set_timeout(timeout);
             self.runtime.block_on(self.client.clone().heartbeat(req))?;
             Ok(())
         })
     }
 }
 
+impl LighthouseClient {
+    /// Attach `"room-id"` header if `self.room_id` is Some(_)
+    fn add_room_header<T>(&self, mut req: tonic::Request<T>) -> tonic::Request<T> {
+        if let Some(ref id) = self.room_id {
+            use tonic::metadata::MetadataValue;
+            req.metadata_mut().insert(
+                crate::router::ROOM_ID_HEADER,
+                MetadataValue::try_from(id.as_str()).expect("room-id ascii"),
+            );
+        }
+        req
+    }
+
+}
+
 /// LighthouseServer is a GRPC server for the lighthouse service.
 ///
 /// It is used to coordinate the ManagerServer for each replica group.
```
```diff
@@ -603,7 +637,7 @@ impl LighthouseClient {
 /// heartbeat_timeout_ms (int): The timeout for heartbeats.
 #[pyclass]
 struct LighthouseServer {
-    lighthouse: Arc<lighthouse::Lighthouse>,
+    bind: String,
     handle: JoinHandle<Result<()>>,
     _runtime: Runtime,
 }
@@ -631,19 +665,30 @@ impl LighthouseServer {
             .enable_all()
             .build()?;
 
-        let lighthouse = rt
-            .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
-                bind: bind,
-                min_replicas: min_replicas,
-                join_timeout_ms: join_timeout_ms,
-                quorum_tick_ms: quorum_tick_ms,
-                heartbeat_timeout_ms: heartbeat_timeout_ms,
-            }))
-            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+        let opt = lighthouse::LighthouseOpt {
+            bind: bind.clone(),
+            min_replicas,
+            join_timeout_ms,
+            quorum_tick_ms,
+            heartbeat_timeout_ms,
+        };
+
+        let listener = rt.block_on(tokio::net::TcpListener::bind(&bind))?;
+        let bound_sock = listener.local_addr()?;
+        let bound = format!("http://{}", bound_sock);
+        let incoming = TcpListenerStream::new(listener);
+
+        let handle = rt.spawn(async move {
+            tonic::transport::Server::builder()
+                .add_service(LighthouseServiceServer::new(Router::new(opt.clone())))
+                .serve_with_incoming(incoming)
+                .await
+                .map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))
+        });
 
         Ok(Self {
-            handle: rt.spawn(lighthouse.clone().run()),
-            lighthouse: lighthouse,
+            bind: bound,
+            handle,
             _runtime: rt,
         })
     })
```
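The construction above binds the listener eagerly so that a wildcard bind like `0.0.0.0:0` resolves to a concrete port before serving, which is why `serve_with_incoming` replaces `serve(addr)`. A minimal standalone sketch of that bind-first pattern (independent of this PR's code):

```rust
use tokio::net::TcpListener;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Port 0 asks the kernel for any free port.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    // local_addr() reports the address actually bound, e.g. 127.0.0.1:54321.
    let bound = listener.local_addr()?;
    println!("serving at http://{}", bound);
    // A tonic server would then consume the listener via
    // serve_with_incoming(TcpListenerStream::new(listener)).
    Ok(())
}
```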
```diff
@@ -654,7 +699,7 @@ impl LighthouseServer {
     /// Returns:
     ///     str: The address of the lighthouse server.
     fn address(&self) -> PyResult<String> {
-        Ok(self.lighthouse.address().to_string())
+        Ok(self.bind.clone())
     }
 
     /// shutdown shuts down the lighthouse server.
```
Member: this unfortunately isn't sufficient -- bind could be something like "0.0.0.0:0", which will bind to a random port. The address needs to be the routable http address, i.e. …

Contributor (author): Hmm, perhaps we could use calls similar to the ones the Lighthouse class uses to resolve the host IP and address? I'll include a version of this in the next commit, though I'm also open to changing it.
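One way that suggestion could look, sketched here with the `gethostname` crate standing in for whatever resolution the Lighthouse class actually performs (the crate choice and the `routable_address` helper are assumptions for illustration, not code from this PR):

```rust
use std::net::SocketAddr;

// Hypothetical helper: turn a possibly-wildcard bound address into a
// routable URL by substituting the machine's hostname for 0.0.0.0/[::].
fn routable_address(bound: SocketAddr) -> String {
    let host = gethostname::gethostname()
        .into_string()
        .unwrap_or_else(|_| "localhost".to_string());
    if bound.ip().is_unspecified() {
        format!("http://{}:{}", host, bound.port())
    } else {
        format!("http://{}", bound)
    }
}
```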
src/router.rs (new file)

```rust
use std::sync::Arc;

use dashmap::{mapref::entry::Entry, DashMap};
use tonic::{Request, Response, Status};

use crate::{
    lighthouse::{Lighthouse, LighthouseOpt},
    torchftpb::{
        lighthouse_service_server::LighthouseService, LighthouseHeartbeatRequest,
        LighthouseHeartbeatResponse, LighthouseQuorumRequest, LighthouseQuorumResponse,
    },
};

/// Metadata header used by both client and router
pub const ROOM_ID_HEADER: &str = "room-id";

/// Top-level service registered with tonic's `Server::builder()`
#[derive(Clone)]
```
Member: why does Router need to be Cloneable?

Contributor (author): I mainly made Router Cloneable so that the calls to tonic's add_service would compile when constructing the LighthouseServer in src/bin/lighthouse.rs and src/lib.rs.
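For context, the Clone is cheap here: the room table lives behind an `Arc`, so every clone shares the same map. A toy example of the same shape (not part of this PR):

```rust
use dashmap::DashMap;
use std::sync::Arc;

#[derive(Clone)]
struct Shared {
    rooms: Arc<DashMap<String, u32>>, // Arc: clones share one map
}

fn main() {
    let a = Shared { rooms: Arc::new(DashMap::new()) };
    let b = a.clone(); // bumps a refcount; no deep copy of the map
    a.rooms.insert("jobA".to_string(), 1);
    assert_eq!(*b.rooms.get("jobA").unwrap(), 1); // b sees a's insert
}
```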
```rust
pub struct Router {
    rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
    tmpl_opt: LighthouseOpt, // cloned for each new room
}

/// Multiplexes a single tonic gRPC server into many logical "rooms."
/// Inspects the `room-id` metadata header on each request, then
/// lazily creates or reuses an `Arc<Lighthouse>` for that namespace.
impl Router {
    /// Create a new router given the CLI/config options that are
    /// normally passed straight to `Lighthouse::new`.
    pub fn new(tmpl_opt: LighthouseOpt) -> Self {
        Self {
            rooms: Arc::new(DashMap::new()),
            tmpl_opt,
        }
    }

    /// Room lookup: creates the room if it doesn't exist, reuses it if it does.
    async fn room(&self, id: &str) -> Arc<Lighthouse> {
        // 1. Quick optimistic read (no locking contention).
        if let Some(handle) = self.rooms.get(id) {
            return handle.clone();
        }

        // 2. Build the Lighthouse instance *off the map* so
        //    we don't hold any guard across `.await`.
        let new_room = Lighthouse::new(self.tmpl_opt.clone())
            .await
            .expect("failed to create Lighthouse");

        // 3. Second pass: insert if still vacant, otherwise reuse
        //    whatever another task inserted first.
        match self.rooms.entry(id.to_owned()) {
            Entry::Occupied(entry) => entry.get().clone(),
            Entry::Vacant(entry) => {
                entry.insert(new_room.clone());
                new_room
            }
        }
    }

    /// Extracts `"room-id"` from metadata, defaulting to `"default"`.
    fn extract_room_id(meta: &tonic::metadata::MetadataMap) -> &str {
        meta.get(ROOM_ID_HEADER)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("default")
    }
}

#[tonic::async_trait]
impl LighthouseService for Router {
    async fn quorum(
        &self,
        req: Request<LighthouseQuorumRequest>,
    ) -> Result<Response<LighthouseQuorumResponse>, Status> {
        let id = Self::extract_room_id(req.metadata()).to_owned();
        let room = self.room(&id).await;
        <Arc<Lighthouse> as LighthouseService>::quorum(&room, req).await
    }

    async fn heartbeat(
        &self,
        req: Request<LighthouseHeartbeatRequest>,
    ) -> Result<Response<LighthouseHeartbeatResponse>, Status> {
        let id = Self::extract_room_id(req.metadata()).to_owned();
        let room = self.room(&id).await;
        <Arc<Lighthouse> as LighthouseService>::heartbeat(&room, req).await
    }
}
```
Member: I think this is fine as is, since this is fairly minimal boilerplate per request, but I think we can do even better. By doing this at the Service layer instead of the LighthouseService layer, we can have it automatically work for all endpoints on the LighthouseService. Can you look into this and see how feasible it is? If it's not any cleaner, we can land this as is. Some pointers: there's also https://github.com/teimuraz/tonic-middleware, which might be useful.

Contributor (author): I made an initial attempt to do the routing at the Service layer rather than the LighthouseService layer, but had trouble adapting between the initial tonic message types (…). If I were to keep at this, I'd see if I could get something working that relies more on …

Member: mixing the two is a bit tricky -- we probably need to stay at the tower layer. Why do you need to access the tonic::Request/Response objects? It's all HTTP at the end of the day, so it seems like we should be able to operate at the tower/http layer and view the metadata as a header? Middleware might work, though it may be too high level.

Contributor (author): Ah I see, it became easier when I had router.rs operate entirely at the tower layer rather than trying to mix Service and tonic. The most recent commit has router.rs at the tower level, which lets us start the lighthouse server with a call to …
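As a rough illustration of the tower-layer idea discussed above: a `tower::Service` wrapper sees gRPC metadata as plain HTTP/2 headers, so the room id can be read before tonic's generated code ever runs. The `RoomRouter` name is invented for this sketch; it is not the code from the later commit:

```rust
use http::Request;
use std::task::{Context, Poll};
use tower::Service;

/// Hypothetical tower-layer wrapper around an inner HTTP service.
#[derive(Clone)]
struct RoomRouter<S> {
    inner: S,
}

impl<S, B> Service<Request<B>> for RoomRouter<S>
where
    S: Service<Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        // gRPC metadata travels as HTTP/2 headers, so no tonic types are
        // needed to inspect it here.
        let room = req
            .headers()
            .get("room-id")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("default")
            .to_owned();
        // A real implementation would pick or create the per-room service
        // based on `room`; this sketch just forwards to one inner service.
        let _ = room;
        self.inner.call(req)
    }
}
```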
New test file (Python)

```python
| """ | ||
| Validate that one Lighthouse server can host isolated quorums | ||
| for multiple logical rooms (job IDs) via `room-id` metadata header. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
||
|
|
||
| import datetime as _dt | ||
|
||
|
|
||
| import pytest | ||
|
|
||
| import torchft._torchft as ext | ||
|
||
|
|
||
| _TIMEOUT = _dt.timedelta(seconds=3) # connect + RPC timeout | ||
|
|
||
|
|
||
| def _client(addr: str, room: str) -> ext.LighthouseClient: | ||
| """Utility: create a client with a logical room-id.""" | ||
| return ext.LighthouseClient(addr, _TIMEOUT, room) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_multi_room_quorums() -> None: | ||
| # 1) one server, any free port | ||
| server = ext.LighthouseServer("[::]:0", 1) | ||
| addr = server.address() | ||
|
|
||
| # 2) two clients in two separate rooms | ||
| a = _client(addr, "jobA") | ||
| b = _client(addr, "jobB") | ||
|
|
||
| # 3) explicit heartbeats (exercises RPC path) | ||
| a.heartbeat("a0") | ||
| b.heartbeat("b0") | ||
|
|
||
| # 4) ask for a quorum from each room | ||
| qa = a.quorum("a0", _TIMEOUT) | ||
| qb = b.quorum("b0", _TIMEOUT) | ||
|
|
||
| # 5) verify the rooms are independent | ||
| assert qa.quorum_id == qb.quorum_id == 1 | ||
| assert len(qa.participants) == 1 and qa.participants[0].replica_id == "a0" | ||
| assert len(qb.participants) == 1 and qb.participants[0].replica_id == "b0" | ||
|
|
||
| # 6) shutdown | ||
| server.shutdown() | ||