diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index f5f44baec74..c641c61cb8d 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -2297,6 +2297,18 @@ where self.connection_updated(source, address, NodeStatus::Connected); } + HandlerEvent::AddProviderSent { query_id } => { + if let Some(query) = self.queries.get_mut(&query_id) { + if let QueryInfo::AddProvider { + phase: AddProviderPhase::AddProvider { .. }, + .. + } = &query.info + { + query.on_success(&source, vec![]); + } + } + } + HandlerEvent::ProtocolNotSupported { endpoint } => { let address = match endpoint { ConnectedPoint::Dialer { address, .. } => Some(address), @@ -2636,31 +2648,12 @@ where } QueryPoolState::Waiting(Some((query, peer_id))) => { let event = query.info.to_request(query.id()); - // TODO: AddProvider requests yield no response, so the query completes - // as soon as all requests have been sent. However, the handler should - // better emit an event when the request has been sent (and report - // an error if sending fails), instead of immediately reporting - // "success" somewhat prematurely here. - if let QueryInfo::AddProvider { - phase: AddProviderPhase::AddProvider { .. }, - .. - } = &query.info - { - query.on_success(&peer_id, vec![]) - } - - if self.connected_peers.contains(&peer_id) { - self.queued_events.push_back(ToSwarm::NotifyHandler { - peer_id, - event, - handler: NotifyHandler::Any, - }); - } else if &peer_id != self.kbuckets.local_key().preimage() { - query.pending_rpcs.push((peer_id, event)); - self.queued_events.push_back(ToSwarm::Dial { - opts: DialOpts::peer_id(peer_id).build(), - }); - } + let event = ToSwarm::NotifyHandler { + peer_id, + handler: NotifyHandler::Any, + event, + }; + return Poll::Ready(event); } QueryPoolState::Waiting(None) | QueryPoolState::Idle => break, } diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 2c7b6c52257..3a7a4edc1be 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -277,6 +277,11 @@ pub enum HandlerEvent { /// The user data passed to the `PutValue`. query_id: QueryId, }, + /// Notification that a one-way request (e.g., AddProvider) has been sent. + AddProviderSent { + /// The user data passed to the request. + query_id: QueryId, + }, } /// Error that can happen when requesting an RPC query. @@ -716,68 +721,69 @@ impl ConnectionHandler for Handler { &mut self, cx: &mut Context<'_>, ) -> Poll> { - loop { - match &mut self.protocol_status { - Some(status) if !status.reported => { - status.reported = true; - let event = if status.supported { - HandlerEvent::ProtocolConfirmed { - endpoint: self.endpoint.clone(), - } - } else { - HandlerEvent::ProtocolNotSupported { - endpoint: self.endpoint.clone(), - } - }; + match &mut self.protocol_status { + Some(status) if !status.reported => { + status.reported = true; + let event = if status.supported { + HandlerEvent::ProtocolConfirmed { + endpoint: self.endpoint.clone(), + } + } else { + HandlerEvent::ProtocolNotSupported { + endpoint: self.endpoint.clone(), + } + }; - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); - } - _ => {} + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); } + _ => {} + } - match self.outbound_substreams.poll_unpin(cx) { - Poll::Ready((Ok(Ok(Some(response))), query_id)) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - process_kad_response(response, query_id), - )) - } - Poll::Ready((Ok(Ok(None)), _)) => { - continue; - } - Poll::Ready((Ok(Err(e)), query_id)) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - HandlerEvent::QueryError { - error: HandlerQueryErr::Io(e), - query_id, - }, - )) - } - Poll::Ready((Err(_timeout), query_id)) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - HandlerEvent::QueryError { - error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()), - query_id, - }, - )) - } - Poll::Pending => {} + match self.outbound_substreams.poll_unpin(cx) { + Poll::Ready((Ok(Ok(Some(response))), query_id)) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + process_kad_response(response, query_id), + )) } - - if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); + Poll::Ready((Ok(Ok(None)), query_id)) => { + // One-way request successfully sent (e.g., AddProvider). + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + HandlerEvent::AddProviderSent { query_id }, + )); } - - if self.outbound_substreams.len() < MAX_NUM_STREAMS { - if let Some((msg, id)) = self.pending_messages.pop_front() { - self.queue_new_stream(id, msg); - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()), - }); - } + Poll::Ready((Ok(Err(e)), query_id)) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + HandlerEvent::QueryError { + error: HandlerQueryErr::Io(e), + query_id, + }, + )) } + Poll::Ready((Err(_timeout), query_id)) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + HandlerEvent::QueryError { + error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()), + query_id, + }, + )) + } + Poll::Pending => {} + } - return Poll::Pending; + if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); } + + if self.outbound_substreams.len() < MAX_NUM_STREAMS { + if let Some((msg, id)) = self.pending_messages.pop_front() { + self.queue_new_stream(id, msg); + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()), + }); + } + } + + return Poll::Pending; } fn on_connection_event(