Skip to content

Commit 4bec540

Browse files
committed
add partial messages to gossipsub router
1 parent 195c009 commit 4bec540

File tree

8 files changed

+687
-29
lines changed

8 files changed

+687
-29
lines changed

comm.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,14 @@ func (p *PubSub) getHelloPacket() *RPC {
3232
}
3333

3434
for t := range subscriptions {
35+
var requestPartial bool
36+
if ts, ok := p.myTopics[t]; ok {
37+
requestPartial = ts.requestPartialMessages
38+
}
3539
as := &pb.RPC_SubOpts{
3640
Topicid: proto.String(t),
3741
Subscribe: proto.Bool(true),
42+
Partial: &requestPartial,
3843
}
3944
rpc.Subscriptions = append(rpc.Subscriptions, as)
4045
}

extensions.go

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package pubsub
22

33
import (
4+
"errors"
5+
"iter"
6+
7+
"github.com/libp2p/go-libp2p-pubsub/partialmessages"
48
pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb"
59
"github.com/libp2p/go-libp2p/core/peer"
610
)
711

812
type PeerExtensions struct {
9-
TestExtension bool
13+
TestExtension bool
14+
PartialMessages bool
1015
}
1116

1217
type TestExtensionConfig struct {
@@ -37,6 +42,7 @@ func peerExtensionsFromRPC(rpc *RPC) PeerExtensions {
3742
out := PeerExtensions{}
3843
if hasPeerExtensions(rpc) {
3944
out.TestExtension = rpc.Control.Extensions.GetTestExtension()
45+
out.PartialMessages = rpc.Control.Extensions.GetPartialMessages()
4046
}
4147
return out
4248
}
@@ -46,9 +52,19 @@ func (pe *PeerExtensions) ExtendRPC(rpc *RPC) *RPC {
4652
if rpc.Control == nil {
4753
rpc.Control = &pubsub_pb.ControlMessage{}
4854
}
49-
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{
50-
TestExtension: &pe.TestExtension,
55+
if rpc.Control.Extensions == nil {
56+
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
57+
}
58+
rpc.Control.Extensions.TestExtension = &pe.TestExtension
59+
}
60+
if pe.PartialMessages {
61+
if rpc.Control == nil {
62+
rpc.Control = &pubsub_pb.ControlMessage{}
63+
}
64+
if rpc.Control.Extensions == nil {
65+
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
5166
}
67+
rpc.Control.Extensions.PartialMessages = &pe.PartialMessages
5268
}
5369
return rpc
5470
}
@@ -59,8 +75,9 @@ type extensionsState struct {
5975
sentExtensions map[peer.ID]struct{}
6076
reportMisbehavior func(peer.ID)
6177
sendRPC func(p peer.ID, r *RPC, urgent bool)
78+
testExtension *testExtension
6279

63-
testExtension *testExtension
80+
partialMessagesExtension *partialmessages.PartialMessageExtension
6481
}
6582

6683
func newExtensionsState(myExtensions PeerExtensions, reportMisbehavior func(peer.ID), sendRPC func(peer.ID, *RPC, bool)) *extensionsState {
@@ -132,14 +149,79 @@ func (es *extensionsState) extensionsAddPeer(id peer.ID) {
132149
if es.myExtensions.TestExtension && es.peerExtensions[id].TestExtension {
133150
es.testExtension.AddPeer(id)
134151
}
152+
153+
if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages {
154+
es.partialMessagesExtension.AddPeer(id)
155+
}
135156
}
136157

137158
// extensionsRemovePeer is always called after extensionsAddPeer.
138159
func (es *extensionsState) extensionsRemovePeer(id peer.ID) {
160+
if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages {
161+
es.partialMessagesExtension.RemovePeer(id)
162+
}
139163
}
140164

141165
func (es *extensionsState) extensionsHandleRPC(rpc *RPC) {
142166
if es.myExtensions.TestExtension && es.peerExtensions[rpc.from].TestExtension {
143167
es.testExtension.HandleRPC(rpc.from, rpc.TestExtension)
144168
}
169+
170+
if es.myExtensions.PartialMessages && es.peerExtensions[rpc.from].PartialMessages && rpc.Partial != nil {
171+
es.partialMessagesExtension.HandleRPC(rpc.from, rpc.Partial)
172+
}
173+
}
174+
175+
func (es *extensionsState) Heartbeat() {
176+
if es.myExtensions.PartialMessages {
177+
es.partialMessagesExtension.Heartbeat()
178+
}
179+
}
180+
181+
func WithPartialMessagesExtension(pm *partialmessages.PartialMessageExtension) Option {
182+
return func(ps *PubSub) error {
183+
gs, ok := ps.rt.(*GossipSubRouter)
184+
if !ok {
185+
return errors.New("pubsub router is not gossipsub")
186+
}
187+
err := pm.Init(partialMessageRouter{gs})
188+
if err != nil {
189+
return err
190+
}
191+
192+
gs.extensions.myExtensions.PartialMessages = true
193+
gs.extensions.partialMessagesExtension = pm
194+
return nil
195+
}
196+
}
197+
198+
type partialMessageRouter struct {
199+
gs *GossipSubRouter
200+
}
201+
202+
// MeshPeers implements partialmessages.Router.
203+
func (r partialMessageRouter) MeshPeers(topic string) iter.Seq[peer.ID] {
204+
return func(yield func(peer.ID) bool) {
205+
for peer := range r.gs.mesh[topic] {
206+
if exts := r.gs.extensions.peerExtensions[peer]; exts.PartialMessages {
207+
if peerStates, ok := r.gs.p.topics[topic]; ok && peerStates[peer].requestsPartial {
208+
// Check that the peer wanted partial messages
209+
if !yield(peer) {
210+
return
211+
}
212+
}
213+
}
214+
}
215+
}
216+
}
217+
218+
// SendRPC implements partialmessages.Router.
219+
func (r partialMessageRouter) SendRPC(p peer.ID, rpc *pubsub_pb.PartialMessagesExtension, urgent bool) {
220+
r.gs.sendRPC(p, &RPC{
221+
RPC: pubsub_pb.RPC{
222+
Partial: rpc,
223+
},
224+
}, urgent)
145225
}
226+
227+
var _ partialmessages.Router = partialMessageRouter{}

gossipsub.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,10 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
13751375
if pid == from || pid == peer.ID(msg.GetFrom()) {
13761376
continue
13771377
}
1378+
if peerStates, ok := gs.p.topics[topic]; ok && peerStates[pid].requestsPartial {
1379+
// The peer requested partial messages. We'll skip sending them full messages
1380+
continue
1381+
}
13781382

13791383
if !yield(pid, out) {
13801384
return
@@ -1833,6 +1837,8 @@ func (gs *GossipSubRouter) heartbeat() {
18331837

18341838
// advance the message history window
18351839
gs.mcache.Shift()
1840+
1841+
gs.extensions.Heartbeat()
18361842
}
18371843

18381844
func (gs *GossipSubRouter) clearIHaveCounters() {

0 commit comments

Comments
 (0)