diff --git a/config/config.go b/config/config.go index 8e730b19ee..5854bc3ff5 100644 --- a/config/config.go +++ b/config/config.go @@ -28,8 +28,10 @@ import ( bhost "github.com/libp2p/go-libp2p/p2p/host/basic" blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/host/natmanager" "github.com/libp2p/go-libp2p/p2p/host/observedaddrs" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" + "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" routed "github.com/libp2p/go-libp2p/p2p/host/routed" "github.com/libp2p/go-libp2p/p2p/net/swarm" @@ -39,6 +41,7 @@ import ( relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" + "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/insecure" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" @@ -117,9 +120,10 @@ type Config struct { ConnManager connmgr.ConnManager ResourceManager network.ResourceManager - NATManager NATManagerC - Peerstore peerstore.Peerstore - Reporter metrics.Reporter + EnableNATPortMap bool + NATManager bhost.NATManager + Peerstore peerstore.Peerstore + Reporter metrics.Reporter MultiaddrResolver network.MultiaddrDNSResolver @@ -446,13 +450,10 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus, an *auton ConnManager: cfg.ConnManager, AddrsFactory: cfg.AddrsFactory, NATManager: cfg.NATManager, - EnablePing: !cfg.DisablePing, UserAgent: cfg.UserAgent, ProtocolVersion: cfg.ProtocolVersion, EnableHolePunching: cfg.EnableHolePunching, HolePunchingOptions: cfg.HolePunchingOptions, - EnableRelayService: cfg.EnableRelayService, - RelayServiceOpts: cfg.RelayServiceOpts, EnableMetrics: !cfg.DisableMetrics, PrometheusRegisterer: cfg.PrometheusRegisterer, AutoNATv2: an, @@ -554,6 +555,17 @@ func (cfg *Config) NewNode() (host.Host, error) { }) return o, nil }), + fx.Provide(func(s *swarm.Swarm, lifecycle fx.Lifecycle) (bhost.NATManager, error) { + if !cfg.EnableNATPortMap && cfg.NATManager == nil { + return nil, nil + } + if cfg.NATManager != nil { + return cfg.NATManager, nil + } + nm := natmanager.New(s) + lifecycle.Append(fx.StartStopHook(nm.Start, nm.Close)) + return nm, nil + }), fx.Provide(func() (*autonatv2.AutoNAT, error) { if !cfg.EnableAutoNATv2 { return nil, nil @@ -580,6 +592,13 @@ func (cfg *Config) NewNode() (host.Host, error) { return bh }), fx.Provide(func(h *swarm.Swarm) peer.ID { return h.LocalPeer() }), + fx.Provide(func(h host.Host) (*pstoremanager.PeerstoreManager, error) { + psm, err := pstoremanager.NewPeerstoreManager(h.Peerstore(), h.EventBus(), h.Network()) + if err != nil { + return nil, fmt.Errorf("failed to create PeerstoreManager: %w", err) + } + return psm, nil + }), } transportOpts, err := cfg.addTransports() if err != nil { @@ -597,6 +616,11 @@ func (cfg *Config) NewNode() (host.Host, error) { ) } + fxopts = append(fxopts, fx.Invoke(func(psm *pstoremanager.PeerstoreManager, lifecycle fx.Lifecycle) error { + lifecycle.Append(fx.StartStopHook(psm.Start, psm.Close)) + return nil + })) + // enable autorelay fxopts = append(fxopts, fx.Invoke(func(h *bhost.BasicHost, lifecycle fx.Lifecycle) error { @@ -619,6 +643,30 @@ func (cfg *Config) NewNode() (host.Host, error) { }), ) + if !cfg.DisablePing { + fxopts = append(fxopts, fx.Invoke(func(h *bhost.BasicHost) { + ping.NewPingService(h) + })) + } + + if cfg.EnableRelayService { + fxopts = append(fxopts, fx.Invoke(func(h host.Host, lifecycle fx.Lifecycle) error { + if !cfg.DisableMetrics { + // Prefer explicitly provided metrics tracer + metricsOpt := []relayv2.Option{ + relayv2.WithMetricsTracer( + relayv2.NewMetricsTracer(relayv2.WithRegisterer(cfg.PrometheusRegisterer)))} + cfg.RelayServiceOpts = append(metricsOpt, cfg.RelayServiceOpts...) + } + rs, err := relayv2.New(h, cfg.RelayServiceOpts...) + if err != nil { + return err + } + lifecycle.Append(fx.StartStopHook(rs.Start, rs.Close)) + return nil + })) + } + var bh *bhost.BasicHost fxopts = append(fxopts, fx.Invoke(func(bho *bhost.BasicHost) { bh = bho })) fxopts = append(fxopts, fx.Invoke(func(h *bhost.BasicHost, lifecycle fx.Lifecycle) { diff --git a/libp2p_test.go b/libp2p_test.go index c35ce8a9ba..8fbdf37cce 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -29,6 +29,7 @@ import ( "github.com/libp2p/go-libp2p/core/pnet" "github.com/libp2p/go-libp2p/core/routing" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/ping" @@ -44,6 +45,7 @@ import ( "github.com/pion/webrtc/v4" quicgo "github.com/quic-go/quic-go" wtgo "github.com/quic-go/webtransport-go" + "go.uber.org/fx" "go.uber.org/goleak" ma "github.com/multiformats/go-multiaddr" @@ -917,3 +919,47 @@ func TestConnAs(t *testing.T) { }) } } + +func TestPeerstoreManager(t *testing.T) { + ctx := t.Context() + + var psm1, psm2 *pstoremanager.PeerstoreManager + // Create two hosts + h1, err := New(WithFxOption(fx.Populate(&psm1))) + require.NoError(t, err) + defer h1.Close() + require.NotNil(t, psm1) + + h2, err := New(WithFxOption(fx.Populate(&psm2))) + require.NoError(t, err) + require.NotNil(t, psm2) + defer h2.Close() + + // Set stream handlers to establish protocols on each host + h1.SetStreamHandler("/test/protocol/1.0.0", func(s network.Stream) { + s.Close() + }) + h2.SetStreamHandler("/test/protocol/2.0.0", func(s network.Stream) { + s.Close() + }) + + // Connect the two hosts + err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + require.NoError(t, err) + + // Disconnect the hosts + err = h1.Network().ClosePeer(h2.ID()) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Check that h1 has h2's protocol info in its peerstore + h2Protocols, err := h1.Peerstore().GetProtocols(h2.ID()) + require.NoError(t, err) + require.NotEmpty(t, h2Protocols, "h1 should have h2's protocol info after disconnect") + + // Check that h2 has h1's protocol info in its peerstore + h1Protocols, err := h2.Peerstore().GetProtocols(h1.ID()) + require.NoError(t, err) + require.NotEmpty(t, h1Protocols, "h2 should have h1's protocol info after disconnect") +} diff --git a/options.go b/options.go index 0329b7e60b..b8f2589e48 100644 --- a/options.go +++ b/options.go @@ -418,13 +418,22 @@ func ResourceManager(rcmgr network.ResourceManager) Option { // NATPortMap configures libp2p to use the default NATManager. The default // NATManager will attempt to open a port in your network's firewall using UPnP. func NATPortMap() Option { - return NATManager(bhost.NewNATManager) + return func(cfg *Config) error { + if cfg.NATManager != nil { + return fmt.Errorf("cannot enable both NATManager and NATPortMap") + } + cfg.EnableNATPortMap = true + return nil + } } // NATManager will configure libp2p to use the requested NATManager. This // function should be passed a NATManager *constructor* that takes a libp2p Network. -func NATManager(nm config.NATManagerC) Option { +func NATManager(nm bhost.NATManager) Option { return func(cfg *Config) error { + if cfg.EnableNATPortMap { + return fmt.Errorf("cannot enable both NATManager and NATPortMap") + } if cfg.NATManager != nil { return fmt.Errorf("cannot specify multiple NATManagers") } diff --git a/p2p/host/basic/addrs_manager.go b/p2p/host/basic/addrs_manager.go index 503ee9b08b..57e97d19a9 100644 --- a/p2p/host/basic/addrs_manager.go +++ b/p2p/host/basic/addrs_manager.go @@ -43,6 +43,16 @@ type ObservedAddrsManager interface { AddrsFor(local ma.Multiaddr) []ma.Multiaddr } +// NATManager is a simple interface to manage NAT devices. +// It listens Listen and ListenClose notifications from the network.Network, +// and tries to obtain port mappings for those. +type NATManager interface { + GetMapping(ma.Multiaddr) ma.Multiaddr + HasDiscoveredNAT() bool + Start() + io.Closer +} + type hostAddrs struct { addrs []ma.Multiaddr localAddrs []ma.Multiaddr @@ -156,12 +166,6 @@ func (a *addrsManager) Start() error { func (a *addrsManager) Close() { a.ctxCancel() - if a.natManager != nil { - err := a.natManager.Close() - if err != nil { - log.Warn("error closing natmgr", "err", err) - } - } if a.addrsReachabilityTracker != nil { err := a.addrsReachabilityTracker.Close() if err != nil { diff --git a/p2p/host/basic/addrs_manager_test.go b/p2p/host/basic/addrs_manager_test.go index 6183327896..121d373b98 100644 --- a/p2p/host/basic/addrs_manager_test.go +++ b/p2p/host/basic/addrs_manager_test.go @@ -43,6 +43,8 @@ func (*mockNatManager) HasDiscoveredNAT() bool { return true } +func (*mockNatManager) Start() {} + var _ NATManager = &mockNatManager{} type mockObservedAddrs struct { diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index d98aa0e337..1d11cce681 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -18,13 +18,9 @@ import ( "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/eventbus" - "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" - "github.com/libp2p/go-libp2p/p2p/host/relaysvc" "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" - relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" - "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/prometheus/client_golang/prometheus" logging "github.com/libp2p/go-libp2p/gologshim" @@ -59,15 +55,12 @@ type BasicHost struct { // keep track of resources we need to wait on before shutting down refCount sync.WaitGroup - network network.Network - psManager *pstoremanager.PeerstoreManager - mux *msmux.MultistreamMuxer[protocol.ID] - ids identify.IDService - hps *holepunch.Service - pings *ping.PingService - cmgr connmgr.ConnManager - eventbus event.Bus - relayManager *relaysvc.RelayManager + network network.Network + mux *msmux.MultistreamMuxer[protocol.ID] + ids identify.IDService + hps *holepunch.Service + cmgr connmgr.ConnManager + eventbus event.Bus negtimeout time.Duration @@ -105,19 +98,11 @@ type HostOpts struct { // NATManager takes care of setting NAT port mappings, and discovering external addresses. // If omitted, this will simply be disabled. - NATManager func(network.Network) NATManager + NATManager NATManager // ConnManager is a libp2p connection manager ConnManager connmgr.ConnManager - // EnablePing indicates whether to instantiate the ping service - EnablePing bool - - // EnableRelayService enables the circuit v2 relay (if we're publicly reachable). - EnableRelayService bool - // RelayServiceOpts are options for the circuit v2 relay. - RelayServiceOpts []relayv2.Option - // UserAgent sets the user-agent for the host. UserAgent string @@ -136,8 +121,6 @@ type HostOpts struct { EnableMetrics bool // PrometheusRegisterer is the PrometheusRegisterer used for metrics PrometheusRegisterer prometheus.Registerer - // AutoNATv2MetricsTracker tracks AutoNATv2 address reachability metrics - AutoNATv2MetricsTracker MetricsTracker // ObservedAddrsManager maps our local listen addresses to external publicly observed addresses. ObservedAddrsManager ObservedAddrsManager @@ -154,15 +137,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { opts.EventBus = eventbus.NewBus() } - psManager, err := pstoremanager.NewPeerstoreManager(n.Peerstore(), opts.EventBus, n) - if err != nil { - return nil, err - } - hostCtx, cancel := context.WithCancel(context.Background()) h := &BasicHost{ network: n, - psManager: psManager, mux: msmux.NewMultistreamMuxer[protocol.ID](), negtimeout: DefaultNegotiationTimeout, eventbus: opts.EventBus, @@ -170,6 +147,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { ctxCancel: cancel, } + var err error if h.emitters.evtLocalProtocolsUpdated, err = h.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}, eventbus.Stateful); err != nil { return nil, err } @@ -203,11 +181,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { addrFactory = opts.AddrsFactory } - var natmgr NATManager - if opts.NATManager != nil { - natmgr = opts.NATManager(h.Network()) - } - if opts.AutoNATv2 != nil { h.autonatv2 = opts.AutoNATv2 } @@ -229,7 +202,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.addressManager, err = newAddrsManager( h.eventbus, - natmgr, + opts.NATManager, addrFactory, h.Network().ListenAddresses, addCertHashesFunc, @@ -270,21 +243,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { n.Notify(h.cmgr.Notifee()) } - if opts.EnableRelayService { - if opts.EnableMetrics { - // Prefer explicitly provided metrics tracer - metricsOpt := []relayv2.Option{ - relayv2.WithMetricsTracer( - relayv2.NewMetricsTracer(relayv2.WithRegisterer(opts.PrometheusRegisterer)))} - opts.RelayServiceOpts = append(metricsOpt, opts.RelayServiceOpts...) - } - h.relayManager = relaysvc.NewRelayManager(h, opts.RelayServiceOpts...) - } - - if opts.EnablePing { - h.pings = ping.NewPingService(h) - } - n.SetStreamHandler(h.newStreamHandler) return h, nil @@ -293,7 +251,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { // Start starts background tasks in the host // TODO: Return error and handle it in the caller? func (h *BasicHost) Start() { - h.psManager.Start() if h.autonatv2 != nil { err := h.autonatv2.Start(h) if err != nil { @@ -637,9 +594,6 @@ func (h *BasicHost) Close() error { if h.autoNat != nil { h.autoNat.Close() } - if h.relayManager != nil { - h.relayManager.Close() - } if h.hps != nil { h.hps.Close() } @@ -654,7 +608,6 @@ func (h *BasicHost) Close() error { } h.addressManager.Close() - h.psManager.Close() if h.Peerstore() != nil { h.Peerstore().Close() } diff --git a/p2p/host/basic/mocks.go b/p2p/host/basic/mocks.go deleted file mode 100644 index a29a0c5ef7..0000000000 --- a/p2p/host/basic/mocks.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build gomock || generate - -package basichost - -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT" -type NAT nat diff --git a/p2p/host/basic/mock_nat_test.go b/p2p/host/natmanager/mock_nat_test.go similarity index 90% rename from p2p/host/basic/mock_nat_test.go rename to p2p/host/natmanager/mock_nat_test.go index 924e52c566..5c08331302 100644 --- a/p2p/host/basic/mock_nat_test.go +++ b/p2p/host/natmanager/mock_nat_test.go @@ -1,13 +1,13 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/libp2p/go-libp2p/p2p/host/basic (interfaces: NAT) +// Source: github.com/libp2p/go-libp2p/p2p/host/natmanager (interfaces: NAT) // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT +// mockgen -build_flags=-tags=gomock -package natmanager -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/natmanager NAT // -// Package basichost is a generated GoMock package. -package basichost +// Package natmanager is a generated GoMock package. +package natmanager import ( context "context" diff --git a/p2p/host/natmanager/mocks.go b/p2p/host/natmanager/mocks.go new file mode 100644 index 0000000000..b9e2f56ed8 --- /dev/null +++ b/p2p/host/natmanager/mocks.go @@ -0,0 +1,6 @@ +//go:build gomock || generate + +package natmanager + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package natmanager -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/natmanager NAT" +type NAT nat diff --git a/p2p/host/basic/natmgr.go b/p2p/host/natmanager/natmgr.go similarity index 86% rename from p2p/host/basic/natmgr.go rename to p2p/host/natmanager/natmgr.go index 0deb407f85..f3ccece20f 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/natmanager/natmgr.go @@ -1,4 +1,4 @@ -package basichost +package natmanager import ( "context" @@ -9,25 +9,15 @@ import ( "sync" "github.com/libp2p/go-libp2p/core/network" + basichost "github.com/libp2p/go-libp2p/p2p/host/basic" inat "github.com/libp2p/go-libp2p/p2p/net/nat" + logging "github.com/libp2p/go-libp2p/gologshim" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) -// NATManager is a simple interface to manage NAT devices. -// It listens Listen and ListenClose notifications from the network.Network, -// and tries to obtain port mappings for those. -type NATManager interface { - GetMapping(ma.Multiaddr) ma.Multiaddr - HasDiscoveredNAT() bool - io.Closer -} - -// NewNATManager creates a NAT manager. -func NewNATManager(net network.Network) NATManager { - return newNATManager(net) -} +var log = logging.Logger("natmanager") type entry struct { protocol string @@ -50,7 +40,7 @@ var discoverNAT = func(ctx context.Context) (nat, error) { return inat.DiscoverN // - natManager listens to the network and adds or closes port mappings // as the network signals Listen() or ListenClose(). // - closing the natManager closes the nat and its mappings. -type natManager struct { +type NATManager struct { net network.Network natMx sync.RWMutex nat nat @@ -64,41 +54,45 @@ type natManager struct { ctxCancel context.CancelFunc } -func newNATManager(net network.Network) *natManager { +var _ basichost.NATManager = (*NATManager)(nil) + +func New(net network.Network) *NATManager { ctx, cancel := context.WithCancel(context.Background()) - nmgr := &natManager{ + nmgr := &NATManager{ net: net, syncFlag: make(chan struct{}, 1), ctx: ctx, ctxCancel: cancel, tracked: make(map[entry]bool), } - nmgr.refCount.Add(1) - go nmgr.background(ctx) return nmgr } +func (nmgr *NATManager) Start() { + nmgr.refCount.Add(1) + go nmgr.background(nmgr.ctx) +} + // Close closes the natManager, closing the underlying nat // and unregistering from network events. -func (nmgr *natManager) Close() error { +func (nmgr *NATManager) Close() error { nmgr.ctxCancel() nmgr.refCount.Wait() return nil } -func (nmgr *natManager) HasDiscoveredNAT() bool { +func (nmgr *NATManager) HasDiscoveredNAT() bool { nmgr.natMx.RLock() defer nmgr.natMx.RUnlock() return nmgr.nat != nil } -func (nmgr *natManager) background(ctx context.Context) { +func (nmgr *NATManager) background(ctx context.Context) { defer nmgr.refCount.Done() defer func() { nmgr.natMx.Lock() defer nmgr.natMx.Unlock() - if nmgr.nat != nil { nmgr.nat.Close() } @@ -133,7 +127,7 @@ func (nmgr *natManager) background(ctx context.Context) { } } -func (nmgr *natManager) sync() { +func (nmgr *NATManager) sync() { select { case nmgr.syncFlag <- struct{}{}: default: @@ -142,7 +136,7 @@ func (nmgr *natManager) sync() { // doSync syncs the current NAT mappings, removing any outdated mappings and adding any // new mappings. -func (nmgr *natManager) doSync() { +func (nmgr *NATManager) doSync() { for e := range nmgr.tracked { nmgr.tracked[e] = false } @@ -214,7 +208,7 @@ func (nmgr *natManager) doSync() { } } -func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { +func (nmgr *NATManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { nmgr.natMx.Lock() defer nmgr.natMx.Unlock() @@ -289,9 +283,9 @@ func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { return extMaddr } -type nmgrNetNotifiee natManager +type nmgrNetNotifiee NATManager -func (nn *nmgrNetNotifiee) natManager() *natManager { return (*natManager)(nn) } +func (nn *nmgrNetNotifiee) natManager() *NATManager { return (*NATManager)(nn) } func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { nn.natManager().sync() } func (nn *nmgrNetNotifiee) ListenClose(_ network.Network, _ ma.Multiaddr) { nn.natManager().sync() } func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {} diff --git a/p2p/host/basic/natmgr_test.go b/p2p/host/natmanager/natmgr_test.go similarity index 98% rename from p2p/host/basic/natmgr_test.go rename to p2p/host/natmanager/natmgr_test.go index be2567dd99..8946876d18 100644 --- a/p2p/host/basic/natmgr_test.go +++ b/p2p/host/natmanager/natmgr_test.go @@ -1,4 +1,4 @@ -package basichost +package natmanager import ( "context" @@ -33,7 +33,8 @@ func TestMapping(t *testing.T) { sw := swarmt.GenSwarm(t) defer sw.Close() - m := newNATManager(sw) + m := New(sw) + m.Start() require.Eventually(t, func() bool { m.natMx.Lock() defer m.natMx.Unlock() @@ -67,7 +68,8 @@ func TestAddAndRemoveListeners(t *testing.T) { sw := swarmt.GenSwarm(t) defer sw.Close() - m := newNATManager(sw) + m := New(sw) + m.Start() require.Eventually(t, func() bool { m.natMx.Lock() defer m.natMx.Unlock() diff --git a/p2p/host/relaysvc/relay.go b/p2p/host/relaysvc/relay.go deleted file mode 100644 index f9bbc7588e..0000000000 --- a/p2p/host/relaysvc/relay.go +++ /dev/null @@ -1,96 +0,0 @@ -package relaysvc - -import ( - "context" - "sync" - - "github.com/libp2p/go-libp2p/core/event" - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/p2p/host/eventbus" - relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" -) - -type RelayManager struct { - host host.Host - - mutex sync.Mutex - relay *relayv2.Relay - opts []relayv2.Option - - refCount sync.WaitGroup - ctxCancel context.CancelFunc -} - -func NewRelayManager(host host.Host, opts ...relayv2.Option) *RelayManager { - ctx, cancel := context.WithCancel(context.Background()) - m := &RelayManager{ - host: host, - opts: opts, - ctxCancel: cancel, - } - m.refCount.Add(1) - go m.background(ctx) - return m -} - -func (m *RelayManager) background(ctx context.Context) { - defer m.refCount.Done() - defer func() { - m.mutex.Lock() - if m.relay != nil { - m.relay.Close() - } - m.mutex.Unlock() - }() - - subReachability, _ := m.host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("relaysvc")) - defer subReachability.Close() - - for { - select { - case <-ctx.Done(): - return - case ev, ok := <-subReachability.Out(): - if !ok { - return - } - if err := m.reachabilityChanged(ev.(event.EvtLocalReachabilityChanged).Reachability); err != nil { - return - } - } - } -} - -func (m *RelayManager) reachabilityChanged(r network.Reachability) error { - switch r { - case network.ReachabilityPublic: - m.mutex.Lock() - defer m.mutex.Unlock() - // This could happen if two consecutive EvtLocalReachabilityChanged report the same reachability. - // This shouldn't happen, but it's safer to double-check. - if m.relay != nil { - return nil - } - relay, err := relayv2.New(m.host, m.opts...) - if err != nil { - return err - } - m.relay = relay - default: - m.mutex.Lock() - defer m.mutex.Unlock() - if m.relay != nil { - err := m.relay.Close() - m.relay = nil - return err - } - } - return nil -} - -func (m *RelayManager) Close() error { - m.ctxCancel() - m.refCount.Wait() - return nil -} diff --git a/p2p/host/relaysvc/relay_test.go b/p2p/host/relaysvc/relay_test.go deleted file mode 100644 index 83a1784ea2..0000000000 --- a/p2p/host/relaysvc/relay_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package relaysvc - -import ( - "testing" - "time" - - "github.com/libp2p/go-libp2p/core/event" - "github.com/libp2p/go-libp2p/core/network" - bhost "github.com/libp2p/go-libp2p/p2p/host/blank" - "github.com/libp2p/go-libp2p/p2p/host/eventbus" - swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" - relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" - "github.com/stretchr/testify/require" -) - -func TestReachabilityChangeEvent(t *testing.T) { - h := bhost.NewBlankHost(swarmt.GenSwarm(t)) - rmgr := NewRelayManager(h) - emitter, err := rmgr.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) - if err != nil { - t.Fatal(err) - } - evt := event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPublic} - emitter.Emit(evt) - require.Eventually( - t, - func() bool { rmgr.mutex.Lock(); defer rmgr.mutex.Unlock(); return rmgr.relay != nil }, - 1*time.Second, - 100*time.Millisecond, - "relay should be set on public reachability") - - evt = event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPrivate} - emitter.Emit(evt) - require.Eventually( - t, - func() bool { rmgr.mutex.Lock(); defer rmgr.mutex.Unlock(); return rmgr.relay == nil }, - 3*time.Second, - 100*time.Millisecond, - "relay should be nil on private reachability") - - evt = event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPublic} - emitter.Emit(evt) - evt = event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityUnknown} - emitter.Emit(evt) - require.Eventually( - t, - func() bool { rmgr.mutex.Lock(); defer rmgr.mutex.Unlock(); return rmgr.relay == nil }, - 3*time.Second, - 100*time.Millisecond, - "relay should be nil on unknown reachability") - - evt = event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPublic} - emitter.Emit(evt) - var relay *relayv2.Relay - require.Eventually( - t, - func() bool { rmgr.mutex.Lock(); defer rmgr.mutex.Unlock(); relay = rmgr.relay; return relay != nil }, - 3*time.Second, - 100*time.Millisecond, - "relay should be set on public event") - emitter.Emit(evt) - require.Never(t, - func() bool { rmgr.mutex.Lock(); defer rmgr.mutex.Unlock(); return relay != rmgr.relay }, - 3*time.Second, - 100*time.Millisecond, - "relay should not be updated on receiving the same event") -} diff --git a/p2p/protocol/circuitv2/relay/relay.go b/p2p/protocol/circuitv2/relay/relay.go index e5d79e0cbc..d475410379 100644 --- a/p2p/protocol/circuitv2/relay/relay.go +++ b/p2p/protocol/circuitv2/relay/relay.go @@ -10,10 +10,12 @@ import ( "time" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/record" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" pbv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/pb" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/util" @@ -59,6 +61,7 @@ type Relay struct { rsvp map[peer.ID]time.Time conns map[peer.ID]int closed bool + wg sync.WaitGroup selfAddr ma.Multiaddr @@ -102,32 +105,40 @@ func New(h host.Host, opts ...Option) (*Relay, error) { r.constraints = newConstraints(&r.rc) r.selfAddr = ma.StringCast(fmt.Sprintf("/p2p/%s", h.ID())) - h.SetStreamHandler(proto.ProtoIDv2Hop, r.handleStream) - r.notifiee = &network.NotifyBundle{DisconnectedF: r.disconnected} - h.Network().Notify(r.notifiee) - if r.metricsTracer != nil { r.metricsTracer.RelayStatus(true) } - go r.background() return r, nil } +func (r *Relay) Start() { + sub, err := r.host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("relay-svc")) + if err != nil { + log.Error("failed to subscribe to reachability event; disabling relay service", "err", err) + return + } + r.notifiee = &network.NotifyBundle{DisconnectedF: r.disconnected} + r.host.Network().Notify(r.notifiee) + r.wg.Add(1) + go r.background(sub) +} + func (r *Relay) Close() error { r.mx.Lock() if !r.closed { + defer r.scope.Done() r.closed = true r.mx.Unlock() - r.host.RemoveStreamHandler(proto.ProtoIDv2Hop) r.host.Network().StopNotify(r.notifiee) - defer r.scope.Done() r.cancel() + r.wg.Wait() r.gc() if r.metricsTracer != nil { r.metricsTracer.RelayStatus(false) } + r.host.RemoveStreamHandler(proto.ProtoIDv2Hop) return nil } r.mx.Unlock() @@ -695,7 +706,9 @@ func (r *Relay) makeLimitMsg(_ peer.ID) *pbv2.Limit { } } -func (r *Relay) background() { +func (r *Relay) background(reachabilitySub event.Subscription) { + defer r.wg.Done() + defer reachabilitySub.Close() ticker := time.NewTicker(time.Minute) defer ticker.Stop() @@ -703,6 +716,17 @@ func (r *Relay) background() { select { case <-ticker.C: r.gc() + case ev, ok := <-reachabilitySub.Out(): + if !ok { + return // subscription close, node's closing? + } + evt := ev.(event.EvtLocalReachabilityChanged) + switch evt.Reachability { + case network.ReachabilityPublic: + r.host.SetStreamHandler(proto.ProtoIDv2Hop, r.handleStream) + default: + r.host.RemoveStreamHandler(proto.ProtoIDv2Hop) + } case <-r.ctx.Done(): return } diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index 93390678dd..8461e95f49 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -16,6 +16,7 @@ import ( "github.com/libp2p/go-libp2p/core/metrics" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/transport" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" "github.com/libp2p/go-libp2p/p2p/host/eventbus" @@ -23,6 +24,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/net/swarm" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" + "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/transport/tcp" "github.com/stretchr/testify/require" @@ -77,6 +79,12 @@ func getNetHosts(t *testing.T, _ context.Context, n int) (hosts []host.Host, upg h := bhost.NewBlankHost(netw, bhost.WithEventBus(bus)) hosts = append(hosts, h) + emitter, err := bus.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) + if err != nil { + t.Fatal(err) + } + // configure the host to be publicly reachable so relay service is started + emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPublic}) } return hosts, upgraders @@ -128,6 +136,7 @@ func TestBasicRelay(t *testing.T) { if err != nil { t.Fatal(err) } + r.Start() defer r.Close() connect(t, hosts[0], hosts[1]) @@ -227,6 +236,7 @@ func TestRelayLimitTime(t *testing.T) { if err != nil { t.Fatal(err) } + r.Start() defer r.Close() connect(t, hosts[0], hosts[1]) @@ -312,6 +322,7 @@ func TestRelayLimitData(t *testing.T) { if err != nil { t.Fatal(err) } + r.Start() defer r.Close() connect(t, hosts[0], hosts[1]) @@ -379,3 +390,49 @@ func TestRelayLimitData(t *testing.T) { } } + +func TestRelayReachabilityEvent(t *testing.T) { + ctx := t.Context() + + hosts, _ := getNetHosts(t, ctx, 1) + h := hosts[0] + defer h.Close() + + // Create a relay service + r, err := relay.New(h) + require.NoError(t, err) + defer r.Close() + r.Start() + + emitter, err := h.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) + require.NoError(t, err) + defer emitter.Close() + + protocolID := protocol.ID(proto.ProtoIDv2Hop) + time.Sleep(100 * time.Millisecond) + require.Contains(t, h.Mux().Protocols(), protocolID, "stream handler should be set initially") + + // Send reachability event: Private + err = emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPrivate}) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + require.NotContains(t, h.Mux().Protocols(), protocolID, "stream handler should be removed when reachability is Private") + + // Send reachability event: Public + err = emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPublic}) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + require.Contains(t, h.Mux().Protocols(), protocolID, "stream handler should be set when reachability is Public") + + // Send reachability event: Unknown + err = emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityUnknown}) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + require.NotContains(t, h.Mux().Protocols(), protocolID, "stream handler should be removed when reachability is Unknown") + + // Send reachability event: Public again + err = emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPublic}) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + require.Contains(t, h.Mux().Protocols(), protocolID, "stream handler should be set when reachability is Public again") +}