From b382755f6ed766f996b516cc61bc204ef98a8699 Mon Sep 17 00:00:00 2001
From: gilpluralis
Date: Tue, 14 Oct 2025 05:39:59 +0000
Subject: [PATCH] Limit concurrent rpc_download_state requests served to one

A peer serves its state without an upper bound on the number of concurrent
requests. If many requests arrive within a short time frame, the peer runs
out of memory (OOM). This change limits the peer to serving only one state
download at a time.
---
 hivemind/averaging/averager.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py
index 0db6129d8..3c6bb839c 100644
--- a/hivemind/averaging/averager.py
+++ b/hivemind/averaging/averager.py
@@ -177,6 +177,7 @@ def __init__(
         self.shutdown_timeout = shutdown_timeout
         self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth
+        self._state_download_lock = asyncio.Lock()
 
         self.matchmaking_kwargs = dict(
             servicer_type=type(self),
@@ -637,18 +638,21 @@ async def rpc_download_state(
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors, infos = await self._get_current_state_from_host_process()
-        if infos is None:
-            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
-        assert len(tensors) == len(infos)
-
-        for tensor, info in zip(tensors, infos):
-            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
-                if metadata is not None:
-                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
-                    metadata = None
-                else:
-                    yield averaging_pb2.DownloadData(tensor_part=part)
+        if self._state_download_lock.locked():
+            return  # decline if already serving another peer
+        async with self._state_download_lock:
+            metadata, tensors, infos = await self._get_current_state_from_host_process()
+            if infos is None:
+                infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+            assert len(tensors) == len(infos)
+
+            for tensor, info in zip(tensors, infos):
+                for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
+                    if metadata is not None:
+                        yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
+                        metadata = None
+                    else:
+                        yield averaging_pb2.DownloadData(tensor_part=part)
 
     def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
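
Note for reviewers: the patch uses a non-blocking check, Lock.locked(), so a second
request is declined immediately rather than queued behind the lock. Below is a minimal,
standalone sketch of that decline-if-busy behavior; serve_state, main, and the peer
names are illustrative stand-ins, not hivemind API.

import asyncio

# Sketch of the decline-if-busy pattern from the patch: a handler checks
# Lock.locked() and refuses immediately instead of waiting for the lock.
async def serve_state(peer: str, state_lock: asyncio.Lock) -> None:
    if state_lock.locked():
        print(f"{peer}: declined, already serving another peer")
        return
    async with state_lock:
        print(f"{peer}: serving state")
        await asyncio.sleep(0.1)  # stands in for streaming compressed tensor parts
        print(f"{peer}: done")

async def main() -> None:
    state_lock = asyncio.Lock()
    # Three peers request the state at once; only the first is served,
    # the other two return right away instead of piling up in memory.
    await asyncio.gather(*(serve_state(f"peer-{i}", state_lock) for i in range(3)))

if __name__ == "__main__":
    asyncio.run(main())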