Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions comm.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@ func (p *PubSub) getHelloPacket() *RPC {
}

for t := range subscriptions {
var requestPartial bool
if ts, ok := p.myTopics[t]; ok {
requestPartial = ts.requestPartialMessages
}
as := &pb.RPC_SubOpts{
Topicid: proto.String(t),
Subscribe: proto.Bool(true),
Partial: &requestPartial,
}
rpc.Subscriptions = append(rpc.Subscriptions, as)
}
Expand Down Expand Up @@ -123,7 +128,7 @@ func (p *PubSub) notifyPeerDead(pid peer.ID) {
}

func (p *PubSub) handleNewPeer(ctx context.Context, pid peer.ID, outgoing *rpcQueue) {
s, err := p.host.NewStream(p.ctx, pid, p.rt.Protocols()...)
s, err := p.host.NewStream(ctx, pid, p.rt.Protocols()...)
if err != nil {
p.logger.Debug("error opening new stream to peer", "err", err, "peer", pid)

Expand All @@ -135,11 +140,14 @@ func (p *PubSub) handleNewPeer(ctx context.Context, pid peer.ID, outgoing *rpcQu
return
}

go p.handleSendingMessages(ctx, s, outgoing)
firstMessage := make(chan *RPC, 1)
sCtx, cancel := context.WithCancel(ctx)
go p.handleSendingMessages(sCtx, s, outgoing, firstMessage)
go p.handlePeerDead(s)
select {
case p.newPeerStream <- s:
case p.newPeerStream <- peerOutgoingStream{Stream: s, FirstMessage: firstMessage, Cancel: cancel}:
case <-ctx.Done():
cancel()
}
}

Expand All @@ -164,7 +172,7 @@ func (p *PubSub) handlePeerDead(s network.Stream) {
p.notifyPeerDead(pid)
}

func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) {
func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue, firstMessage chan *RPC) {
writeRpc := func(rpc *RPC) error {
size := uint64(rpc.Size())

Expand All @@ -177,6 +185,11 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
return err
}

if err := s.SetWriteDeadline(time.Now().Add(time.Second * 30)); err != nil {
p.rpcLogger.Debug("failed to set write deadline", "peer", s.Conn().RemotePeer(), "err", err)
return err
}

_, err = s.Write(buf)
if err != nil {
p.rpcLogger.Debug("failed to send message", "peer", s.Conn().RemotePeer(), "rpc", rpc, "err", err)
Expand All @@ -186,6 +199,21 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
return nil
}

select {
case rpc := <-firstMessage:
if rpc.Size() > 0 {
err := writeRpc(rpc)
if err != nil {
s.Reset()
p.logger.Debug("error writing message to peer", "peer", s.Conn().RemotePeer(), "err", err)
return
}
}
case <-ctx.Done():
s.Reset()
return
}

defer s.Close()
for ctx.Err() == nil {
rpc, err := outgoing.Pop(ctx)
Expand Down
99 changes: 95 additions & 4 deletions extensions.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package pubsub

import (
"errors"
"iter"

"github.com/libp2p/go-libp2p-pubsub/partialmessages"
pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb"
"github.com/libp2p/go-libp2p/core/peer"
)

type PeerExtensions struct {
TestExtension bool
TestExtension bool
PartialMessages bool
}

type TestExtensionConfig struct {
Expand Down Expand Up @@ -37,6 +42,7 @@ func peerExtensionsFromRPC(rpc *RPC) PeerExtensions {
out := PeerExtensions{}
if hasPeerExtensions(rpc) {
out.TestExtension = rpc.Control.Extensions.GetTestExtension()
out.PartialMessages = rpc.Control.Extensions.GetPartialMessages()
}
return out
}
Expand All @@ -46,9 +52,19 @@ func (pe *PeerExtensions) ExtendRPC(rpc *RPC) *RPC {
if rpc.Control == nil {
rpc.Control = &pubsub_pb.ControlMessage{}
}
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{
TestExtension: &pe.TestExtension,
if rpc.Control.Extensions == nil {
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
}
rpc.Control.Extensions.TestExtension = &pe.TestExtension
}
if pe.PartialMessages {
if rpc.Control == nil {
rpc.Control = &pubsub_pb.ControlMessage{}
}
if rpc.Control.Extensions == nil {
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
}
rpc.Control.Extensions.PartialMessages = &pe.PartialMessages
}
return rpc
}
Expand All @@ -59,8 +75,9 @@ type extensionsState struct {
sentExtensions map[peer.ID]struct{}
reportMisbehavior func(peer.ID)
sendRPC func(p peer.ID, r *RPC, urgent bool)
testExtension *testExtension

testExtension *testExtension
partialMessagesExtension *partialmessages.PartialMessageExtension
}

func newExtensionsState(myExtensions PeerExtensions, reportMisbehavior func(peer.ID), sendRPC func(peer.ID, *RPC, bool)) *extensionsState {
Expand Down Expand Up @@ -132,14 +149,88 @@ func (es *extensionsState) extensionsAddPeer(id peer.ID) {
if es.myExtensions.TestExtension && es.peerExtensions[id].TestExtension {
es.testExtension.AddPeer(id)
}

if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages {
es.partialMessagesExtension.AddPeer(id)
}
}

// extensionsRemovePeer is always called after extensionsAddPeer.
func (es *extensionsState) extensionsRemovePeer(id peer.ID) {
if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages {
es.partialMessagesExtension.RemovePeer(id)
}
}

func (es *extensionsState) extensionsHandleRPC(rpc *RPC) {
if es.myExtensions.TestExtension && es.peerExtensions[rpc.from].TestExtension {
es.testExtension.HandleRPC(rpc.from, rpc.TestExtension)
}

if es.myExtensions.PartialMessages && es.peerExtensions[rpc.from].PartialMessages && rpc.Partial != nil {
es.partialMessagesExtension.HandleRPC(rpc.from, rpc.Partial)
}
}

func (es *extensionsState) Heartbeat() {
if es.myExtensions.PartialMessages {
es.partialMessagesExtension.Heartbeat()
}
}

func WithPartialMessagesExtension(pm *partialmessages.PartialMessageExtension) Option {
return func(ps *PubSub) error {
gs, ok := ps.rt.(*GossipSubRouter)
if !ok {
return errors.New("pubsub router is not gossipsub")
}
err := pm.Init(partialMessageRouter{gs})
if err != nil {
return err
}

gs.extensions.myExtensions.PartialMessages = true
gs.extensions.partialMessagesExtension = pm
return nil
}
}

type partialMessageRouter struct {
gs *GossipSubRouter
}

// MeshPeers implements partialmessages.Router.
func (r partialMessageRouter) MeshPeers(topic string) iter.Seq[peer.ID] {
return func(yield func(peer.ID) bool) {
peerSet, ok := r.gs.mesh[topic]
if !ok {
// Possible a fanout topic
peerSet, ok = r.gs.fanout[topic]
if !ok {
return
}
}

for peer := range peerSet {
if exts := r.gs.extensions.peerExtensions[peer]; exts.PartialMessages {
if peerStates, ok := r.gs.p.topics[topic]; ok && peerStates[peer].requestsPartial {
// Check that the peer wanted partial messages
if !yield(peer) {
return
}
}
}
}
}
}

// SendRPC implements partialmessages.Router.
func (r partialMessageRouter) SendRPC(p peer.ID, rpc *pubsub_pb.PartialMessagesExtension, urgent bool) {
r.gs.sendRPC(p, &RPC{
RPC: pubsub_pb.RPC{
Partial: rpc,
},
}, urgent)
}

var _ partialmessages.Router = partialMessageRouter{}
17 changes: 17 additions & 0 deletions gossipsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,12 @@ func (gs *GossipSubRouter) Preprocess(from peer.ID, msgs []*Message) {
// We don't send IDONTWANT to the peer that sent us the messages
continue
}
if gs.peerWantsPartial(p, topic) {
// Don't send IDONTWANT to peers that are using partial messages
// for this topic
continue
}

// send to only peers that support IDONTWANT
if gs.feature(GossipSubFeatureIdontwant, gs.peers[p]) {
idontwant := []*pb.ControlIDontWant{{MessageIDs: mids}}
Expand Down Expand Up @@ -1375,6 +1381,10 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
if pid == from || pid == peer.ID(msg.GetFrom()) {
continue
}
if gs.peerWantsPartial(pid, topic) {
// The peer requested partial messages. We'll skip sending them full messages
continue
}

if !yield(pid, out) {
return
Expand All @@ -1383,6 +1393,11 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
}
}

func (gs *GossipSubRouter) peerWantsPartial(p peer.ID, topic string) bool {
peerStates, ok := gs.p.topics[topic]
return ok && gs.extensions.myExtensions.PartialMessages && peerStates[p].requestsPartial
}

func (gs *GossipSubRouter) Join(topic string) {
gmap, ok := gs.mesh[topic]
if ok {
Expand Down Expand Up @@ -1833,6 +1848,8 @@ func (gs *GossipSubRouter) heartbeat() {

// advance the message history window
gs.mcache.Shift()

gs.extensions.Heartbeat()
}

func (gs *GossipSubRouter) clearIHaveCounters() {
Expand Down
Loading