11package pubsub
22
33import (
4+ "errors"
5+ "iter"
6+
7+ "github.com/libp2p/go-libp2p-pubsub/partialmessages"
48 pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb"
59 "github.com/libp2p/go-libp2p/core/peer"
610)
711
812type PeerExtensions struct {
9- TestExtension bool
13+ TestExtension bool
14+ PartialMessages bool
1015}
1116
1217type TestExtensionConfig struct {
@@ -37,6 +42,7 @@ func peerExtensionsFromRPC(rpc *RPC) PeerExtensions {
3742 out := PeerExtensions {}
3843 if hasPeerExtensions (rpc ) {
3944 out .TestExtension = rpc .Control .Extensions .GetTestExtension ()
45+ out .PartialMessages = rpc .Control .Extensions .GetPartialMessages ()
4046 }
4147 return out
4248}
@@ -46,9 +52,19 @@ func (pe *PeerExtensions) ExtendRPC(rpc *RPC) *RPC {
4652 if rpc .Control == nil {
4753 rpc .Control = & pubsub_pb.ControlMessage {}
4854 }
49- rpc .Control .Extensions = & pubsub_pb.ControlExtensions {
50- TestExtension : & pe .TestExtension ,
55+ if rpc .Control .Extensions == nil {
56+ rpc .Control .Extensions = & pubsub_pb.ControlExtensions {}
57+ }
58+ rpc .Control .Extensions .TestExtension = & pe .TestExtension
59+ }
60+ if pe .PartialMessages {
61+ if rpc .Control == nil {
62+ rpc .Control = & pubsub_pb.ControlMessage {}
63+ }
64+ if rpc .Control .Extensions == nil {
65+ rpc .Control .Extensions = & pubsub_pb.ControlExtensions {}
5166 }
67+ rpc .Control .Extensions .PartialMessages = & pe .PartialMessages
5268 }
5369 return rpc
5470}
@@ -59,8 +75,9 @@ type extensionsState struct {
5975 sentExtensions map [peer.ID ]struct {}
6076 reportMisbehavior func (peer.ID )
6177 sendRPC func (p peer.ID , r * RPC , urgent bool )
78+ testExtension * testExtension
6279
63- testExtension * testExtension
80+ partialMessagesExtension * partialmessages. PartialMessageExtension
6481}
6582
6683func newExtensionsState (myExtensions PeerExtensions , reportMisbehavior func (peer.ID ), sendRPC func (peer.ID , * RPC , bool )) * extensionsState {
@@ -132,14 +149,79 @@ func (es *extensionsState) extensionsAddPeer(id peer.ID) {
132149 if es .myExtensions .TestExtension && es .peerExtensions [id ].TestExtension {
133150 es .testExtension .AddPeer (id )
134151 }
152+
153+ if es .myExtensions .PartialMessages && es .peerExtensions [id ].PartialMessages {
154+ es .partialMessagesExtension .AddPeer (id )
155+ }
135156}
136157
137158// extensionsRemovePeer is always called after extensionsAddPeer.
138159func (es * extensionsState ) extensionsRemovePeer (id peer.ID ) {
160+ if es .myExtensions .PartialMessages && es .peerExtensions [id ].PartialMessages {
161+ es .partialMessagesExtension .RemovePeer (id )
162+ }
139163}
140164
141165func (es * extensionsState ) extensionsHandleRPC (rpc * RPC ) {
142166 if es .myExtensions .TestExtension && es .peerExtensions [rpc .from ].TestExtension {
143167 es .testExtension .HandleRPC (rpc .from , rpc .TestExtension )
144168 }
169+
170+ if es .myExtensions .PartialMessages && es .peerExtensions [rpc .from ].PartialMessages && rpc .Partial != nil {
171+ es .partialMessagesExtension .HandleRPC (rpc .from , rpc .Partial )
172+ }
173+ }
174+
175+ func (es * extensionsState ) Heartbeat () {
176+ if es .myExtensions .PartialMessages {
177+ es .partialMessagesExtension .Heartbeat ()
178+ }
179+ }
180+
181+ func WithPartialMessagesExtension (pm * partialmessages.PartialMessageExtension ) Option {
182+ return func (ps * PubSub ) error {
183+ gs , ok := ps .rt .(* GossipSubRouter )
184+ if ! ok {
185+ return errors .New ("pubsub router is not gossipsub" )
186+ }
187+ err := pm .Init (partialMessageRouter {gs })
188+ if err != nil {
189+ return err
190+ }
191+
192+ gs .extensions .myExtensions .PartialMessages = true
193+ gs .extensions .partialMessagesExtension = pm
194+ return nil
195+ }
196+ }
197+
198+ type partialMessageRouter struct {
199+ gs * GossipSubRouter
200+ }
201+
202+ // MeshPeers implements partialmessages.Router.
203+ func (r partialMessageRouter ) MeshPeers (topic string ) iter.Seq [peer.ID ] {
204+ return func (yield func (peer.ID ) bool ) {
205+ for peer := range r .gs .mesh [topic ] {
206+ if exts := r .gs .extensions .peerExtensions [peer ]; exts .PartialMessages {
207+ if peerStates , ok := r .gs .p .topics [topic ]; ok && peerStates [peer ].requestsPartial {
208+ // Check that the peer wanted partial messages
209+ if ! yield (peer ) {
210+ return
211+ }
212+ }
213+ }
214+ }
215+ }
216+ }
217+
218+ // SendRPC implements partialmessages.Router.
219+ func (r partialMessageRouter ) SendRPC (p peer.ID , rpc * pubsub_pb.PartialMessagesExtension , urgent bool ) {
220+ r .gs .sendRPC (p , & RPC {
221+ RPC : pubsub_pb.RPC {
222+ Partial : rpc ,
223+ },
224+ }, urgent )
145225}
226+
227+ var _ partialmessages.Router = partialMessageRouter {}
0 commit comments