diff --git a/abbrevTree.go b/abbrevTree.go new file mode 100644 index 0000000..d66f9e2 --- /dev/null +++ b/abbrevTree.go @@ -0,0 +1,117 @@ +package multistream + +import ( + "crypto/sha256" +) + +type nodeProtocol[T StringLike] struct { + protocolID T + tombstoneBit bool +} + +type abbrevTree[T StringLike] struct { + root *abbrevNode[T] +} + +type abbrevNode[T StringLike] struct { + p *nodeProtocol[T] + children [256]*abbrevNode[T] +} + +func (at *abbrevTree[T]) Abbreviate(pid T) []byte { + var result []byte + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + return nil + } + + current := at.root + // go furthest in the tree + for _, b := range hash { + if current.children[b] != nil { + result = append(result, b) + current = current.children[b] + } else { + break + } + } + + if current.p != nil && current.p.protocolID == pid && !current.p.tombstoneBit { + return result + } + return nil +} + +func (at *abbrevTree[T]) GetProtocolID(prefix []byte) (T, error) { + if at.root == nil { + return "", ErrUnknownPrefix + } + current := at.root + for _, b := range prefix { + if current.children[b] == nil { + return "", ErrUnknownPrefix + } + current = current.children[b] + } + if current.p == nil { + return "", ErrUnknownPrefix + } + return current.p.protocolID, nil +} + +func (at *abbrevTree[T]) AddProtocol(pid T) { + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + at.root = &abbrevNode[T]{} + } + + current := at.root + for idx, b := range hash { + if current.children[b] == nil { + current.children[b] = &abbrevNode[T]{ + p: &nodeProtocol[T]{ + protocolID: pid, + tombstoneBit: false, + }, + } + return + } + current = current.children[b] + + if current.p != nil { + if current.p.protocolID == pid { + // Resurrect the protocol ID. + current.p.tombstoneBit = false + } else if !current.p.tombstoneBit { + // There is another protocol in this node, so we need to duplicate it down. 
+ h := sha256.Sum256([]byte(current.p.protocolID)) + + if current.children[h[idx+1]] == nil { + // It should be fine to reference the same nodeProtocol instance. + current.children[h[idx+1]] = &abbrevNode[T]{p: current.p} + } + } + } + } +} + +func (at *abbrevTree[T]) RemoveProtocol(pid T) { + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + return + } + current := at.root + for _, b := range hash { + if current.children[b] == nil { + break + } + current = current.children[b] + + if current.p.protocolID == pid { + current.p.tombstoneBit = true + } + } +} diff --git a/abbrevTree_test.go b/abbrevTree_test.go new file mode 100644 index 0000000..cd435e7 --- /dev/null +++ b/abbrevTree_test.go @@ -0,0 +1,155 @@ +package multistream + +import ( + "bytes" + "crypto/sha256" + "testing" +) + +func TestAbbrevTreeAddProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + hash2 := sha256.Sum256([]byte(proto2)) + proto3 := "protocol251" // this one has the same first byte as "protocol1" + hash3 := sha256.Sum256([]byte(proto3)) + + // make sure we don't make mistakes on the hashes + if hash1[0] == hash2[0] { + t.Fatal("the first bytes of hash1 and hash2 should be different") + } + if hash1[0] != hash3[0] { + t.Fatal("the first bytes of hash1 and hash3 should be the same") + } + if hash1[1] == hash3[1] { + t.Fatal("the second bytes of hash1 and hash3 should be different") + } + + // add only proto1 + tree.AddProtocol(proto1) + + if tree.root == nil { + t.Fatal("root should not be nil after adding protocol") + } + if tree.root.children[hash1[0]] == nil || tree.root.children[hash1[0]].p == nil { + t.Fatal("the protocol was not added") + } + if tree.root.children[hash1[0]].p.protocolID != proto1 { + t.Fatal("the protocol ID was wrong") + } + if tree.root.children[hash1[0]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if 
!bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } + + // also add proto2 + tree.AddProtocol(proto2) + + if tree.root.children[hash2[0]] == nil || tree.root.children[hash2[0]].p == nil { + t.Fatal("the protocol was not added") + } + if tree.root.children[hash2[0]].p.protocolID != proto2 { + t.Fatal("the protocol ID was wrong") + } + if tree.root.children[hash2[0]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto2 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto2), []byte{hash2[0]}) { + t.Fatal("abbreviation of proto2 is incorrect") + } + + // add proto3 which has the same first byte of the hash as proto1 + tree.AddProtocol(proto3) + + n1 := tree.root.children[hash1[0]] + // the node at the first level should still be proto1 + if n1.p.protocolID != proto1 { + t.Fatal("the node in the first level should not be modified") + } + // proto1 should be duplicated down + if n1.children[hash1[1]] == nil || n1.children[hash1[1]].p == nil { + t.Fatal("proto1 was not duplicated") + } + if n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } + // proto3 should be added in the second level + if n1.children[hash3[1]] == nil || n1.children[hash3[1]].p == nil { + t.Fatal("proto3 was not added") + } + if n1.children[hash3[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto3 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto3), []byte{hash3[0], hash3[1]}) { + t.Fatal("abbreviation of proto3 is incorrect") + } +} + +func TestAbbrevTreeRemoveProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + proto3 := "protocol251" // this one has the same first byte as "protocol1" + + tree.AddProtocol(proto1) + tree.AddProtocol(proto2) + 
tree.AddProtocol(proto3) + + // remove only proto1 + tree.RemoveProtocol(proto1) + + n1 := tree.root.children[hash1[0]] + if !n1.p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must be set") + } + if !n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must be set") + } + if tree.Abbreviate(proto1) != nil { + t.Fatal("abbreviation of proto1 should be nil") + } +} + +func TestAbbrevTreeResurrectProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + proto3 := "protocol251" // this one has the same first byte as "protocol1" + + tree.AddProtocol(proto1) + tree.AddProtocol(proto2) + tree.AddProtocol(proto3) + tree.RemoveProtocol(proto1) + tree.AddProtocol(proto1) + + n1 := tree.root.children[hash1[0]] + if n1.p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + + // There should be another leaf node added for proto1 + n2 := n1.children[hash1[1]] + if n2.children[hash1[2]] == nil || n2.children[hash1[2]].p == nil { + t.Fatal("proto1 was not added") + } + if n2.children[hash1[2]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1], hash1[2]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } +} diff --git a/lazyClient.go b/lazyClient.go index 3ff48f9..a6ec50d 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -1,6 +1,7 @@ package multistream import ( + "bytes" "fmt" "io" ) @@ -9,7 +10,7 @@ import ( // protocol selection with a MultistreamMuxer. 
func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { return &lazyClientConn[T]{ - protos: []T{ProtocolID, proto}, + protos: []protoInfo[T]{{ID: ProtocolID}, {ID: proto}}, con: c, rhandshakeOnce: newOnce(), @@ -17,12 +18,31 @@ func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { } } +func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) LazyConn { + t := &abbrevTree[T]{} + for _, p := range peerProtos { + t.AddProtocol(p) + } + + abbrv := t.Abbreviate(proto) + return &lazyClientConn[T]{ + protos: []protoInfo[T]{ + {ID: ProtocolID, Abbrev: ProtocolAbbrev}, + {ID: proto, Abbrev: abbrv}, + }, + con: c, + + rhandshakeOnce: newOnce(), + whandshakeOnce: newOnce(), + } +} + // NewMultistream returns a multistream for the given protocol. This will not // perform any protocol selection. If you are using a MultistreamMuxer, use // NewMSSelect. func NewMultistream[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { return &lazyClientConn[T]{ - protos: []T{proto}, + protos: []protoInfo[T]{{ID: proto}}, con: c, rhandshakeOnce: newOnce(), @@ -58,6 +78,11 @@ func (o *once) Do(f func()) { f() } +type protoInfo[T StringLike] struct { + ID T + Abbrev []byte +} + // lazyClientConn is a ReadWriteCloser adapter that lazily negotiates a protocol // using multistream-select on first use. // @@ -74,7 +99,7 @@ type lazyClientConn[T StringLike] struct { werr error // The sequence of protocols to negotiate. - protos []T + protos []protoInfo[T] // The inner connection. 
con io.ReadWriteCloser @@ -104,18 +129,22 @@ func (l *lazyClientConn[T]) Read(b []byte) (int, error) { func (l *lazyClientConn[T]) doReadHandshake() { for _, proto := range l.protos { // read protocol - tok, err := ReadNextToken[T](l.con) + tok, err := ReadNextTokenBytes(l.con) if err != nil { l.rerr = err return } - if tok == "na" { - l.rerr = ErrNotSupported[T]{[]T{proto}} + if bytes.Equal(tok, []byte("na")) { + l.rerr = ErrNotSupported[T]{[]T{proto.ID}} return } - if tok != proto { - l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto) + if proto.Abbrev != nil && !bytes.Equal(tok, proto.Abbrev) { + l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %x != %x )", tok, proto.Abbrev) + return + } + if proto.Abbrev == nil && T(tok) != proto.ID { + l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", T(tok), proto.ID) return } } @@ -131,7 +160,11 @@ func (l *lazyClientConn[T]) doWriteHandshakeWithData(extra []byte) int { defer putWriter(buf) for _, proto := range l.protos { - l.werr = delimWrite(buf, []byte(proto)) + if proto.Abbrev != nil { + l.werr = delimWrite(buf, proto.Abbrev) + } else { + l.werr = delimWrite(buf, []byte(proto.ID)) + } if l.werr != nil { return 0 } diff --git a/multistream.go b/multistream.go index 17e1ef7..1c059d7 100644 --- a/multistream.go +++ b/multistream.go @@ -5,6 +5,7 @@ package multistream import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -18,10 +19,19 @@ import ( // ErrTooLarge is an error to signal that an incoming message was too large var ErrTooLarge = errors.New("incoming message was too large") +// ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown +var ErrUnknownPrefix = errors.New("unknown protocol hash prefix") + // ProtocolID identifies the multistream protocol itself and makes sure // the multistream muxers on both sides of a channel can work with each other. 
const ProtocolID = "/multistream/1.0.0" +// ProtocolAbbrev identifies the multistream protocol abbreviation support +var ProtocolAbbrev = []byte{0xff, 0x11} + +// AbbrevSupportedMSSVersion is the multistream-select version in which protocol abbreviation is supported +const AbbrevSupportedMSSVersion = 2 + var writerPool = sync.Pool{ New: func() interface{} { return bufio.NewWriter(nil) @@ -52,6 +62,7 @@ type Handler[T StringLike] struct { type MultistreamMuxer[T StringLike] struct { handlerlock sync.RWMutex handlers []Handler[T] + abbrevTree abbrevTree[T] } // NewMultistreamMuxer creates a muxer. @@ -134,6 +145,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + msm.abbrevTree.AddProtocol(protocol) msm.removeHandler(protocol) msm.handlers = append(msm.handlers, Handler[T]{ MatchFunc: match, @@ -147,6 +159,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) { msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + msm.abbrevTree.RemoveProtocol(protocol) msm.removeHandler(protocol) } @@ -199,17 +212,21 @@ func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, hand } }() - // Send the multistream protocol ID - // Ignore the error here. We want the handshake to finish, even if the - // other side has closed this rwc for writing. They may have sent us a - // message and closed. Future writers will get an error anyways. - _ = delimWriteBuffered(rwc, []byte(ProtocolID)) - line, err := ReadNextToken[T](rwc) + token, err := ReadNextTokenBytes(rwc) if err != nil { return "", nil, err } - - if line != ProtocolID { + supportAbbrev := false + // Send the multistream protocol ID or the multistream protocol abbreviation + // Ignore the error here. We want the handshake to finish, even if the + // other side has closed this rwc for writing. They may have sent us a + // message and closed. Future writers will get an error anyways. 
+ if bytes.Equal(token, ProtocolAbbrev) { + supportAbbrev = true + _ = delimWriteBuffered(rwc, ProtocolAbbrev) + } else if T(token) == ProtocolID { + _ = delimWriteBuffered(rwc, []byte(ProtocolID)) + } else { rwc.Close() return "", nil, ErrIncorrectVersion } @@ -217,12 +234,27 @@ func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, hand loop: for { // Now read and respond to commands until they send a valid protocol id - tok, err := ReadNextToken[T](rwc) + var proto T + + tok, err := ReadNextTokenBytes(rwc) if err != nil { return "", nil, err } - h := msm.findHandler(tok) + if supportAbbrev { + // decode the protocol abbreviation using the abbreviation tree + msm.handlerlock.RLock() + proto, err = msm.abbrevTree.GetProtocolID(tok) + msm.handlerlock.RUnlock() + + if err != nil { + return "", nil, err + } + } else { + proto = T(tok) + } + + h := msm.findHandler(proto) if h == nil { if err := delimWriteBuffered(rwc, []byte("na")); err != nil { return "", nil, err @@ -233,10 +265,10 @@ loop: // Ignore the error here. We want the handshake to finish, even if the // other side has closed this rwc for writing. They may have sent us a // message and closed. Future writers will get an error anyways. - _ = delimWriteBuffered(rwc, []byte(tok)) + _ = delimWriteBuffered(rwc, tok) // hand off processing to the sub-protocol handler - return tok, h.Handle, nil + return proto, h.Handle, nil } }