Skip to content

Commit e2fcc77

Browse files
committed
ms-select2: compress multiselect proto id
1 parent 6b42268 commit e2fcc77

File tree

2 files changed

+62
-48
lines changed

2 files changed

+62
-48
lines changed

lazyClient.go

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package multistream
22

33
import (
4-
"encoding/hex"
4+
"bytes"
55
"fmt"
66
"io"
77
)
@@ -10,7 +10,7 @@ import (
1010
// protocol selection with a MultistreamMuxer.
1111
func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
1212
return &lazyClientConn[T]{
13-
protos: []T{ProtocolID, proto},
13+
protos: []protoInfo[T]{{ID: ProtocolID}, {ID: proto}},
1414
con: c,
1515

1616
rhandshakeOnce: newOnce(),
@@ -24,11 +24,13 @@ func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) L
2424
t.AddProtocol(p)
2525
}
2626

27-
// TODO: use a proper varint instead of a hex string later
28-
abbrv := T(hex.EncodeToString(t.Abbreviate(proto)))
27+
abbrv := t.Abbreviate(proto)
2928
return &lazyClientConn[T]{
30-
protos: []T{ProtocolID, abbrv},
31-
con: c,
29+
protos: []protoInfo[T]{
30+
{ID: ProtocolID, Abbrev: ProtocolAbbrev},
31+
{ID: proto, Abbrev: abbrv},
32+
},
33+
con: c,
3234

3335
rhandshakeOnce: newOnce(),
3436
whandshakeOnce: newOnce(),
@@ -40,7 +42,7 @@ func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) L
4042
// NewMSSelect.
4143
func NewMultistream[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
4244
return &lazyClientConn[T]{
43-
protos: []T{proto},
45+
protos: []protoInfo[T]{{ID: proto}},
4446
con: c,
4547

4648
rhandshakeOnce: newOnce(),
@@ -76,6 +78,11 @@ func (o *once) Do(f func()) {
7678
f()
7779
}
7880

81+
type protoInfo[T StringLike] struct {
82+
ID T
83+
Abbrev []byte
84+
}
85+
7986
// lazyClientConn is a ReadWriteCloser adapter that lazily negotiates a protocol
8087
// using multistream-select on first use.
8188
//
@@ -92,7 +99,7 @@ type lazyClientConn[T StringLike] struct {
9299
werr error
93100

94101
// The sequence of protocols to negotiate.
95-
protos []T
102+
protos []protoInfo[T]
96103

97104
// The inner connection.
98105
con io.ReadWriteCloser
@@ -122,18 +129,22 @@ func (l *lazyClientConn[T]) Read(b []byte) (int, error) {
122129
func (l *lazyClientConn[T]) doReadHandshake() {
123130
for _, proto := range l.protos {
124131
// read protocol
125-
tok, err := ReadNextToken[T](l.con)
132+
tok, err := ReadNextTokenBytes(l.con)
126133
if err != nil {
127134
l.rerr = err
128135
return
129136
}
130137

131-
if tok == "na" {
132-
l.rerr = ErrNotSupported[T]{[]T{proto}}
138+
if bytes.Equal(tok, []byte("na")) {
139+
l.rerr = ErrNotSupported[T]{[]T{proto.ID}}
133140
return
134141
}
135-
if tok != proto {
136-
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto)
142+
if proto.Abbrev != nil && !bytes.Equal(tok, proto.Abbrev) {
143+
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %x != %x )", tok, proto.Abbrev)
144+
return
145+
}
146+
if proto.Abbrev == nil && T(tok) != proto.ID {
147+
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", T(tok), proto.ID)
137148
return
138149
}
139150
}
@@ -149,7 +160,11 @@ func (l *lazyClientConn[T]) doWriteHandshakeWithData(extra []byte) int {
149160
defer putWriter(buf)
150161

151162
for _, proto := range l.protos {
152-
l.werr = delimWrite(buf, []byte(proto))
163+
if proto.Abbrev != nil {
164+
l.werr = delimWrite(buf, proto.Abbrev)
165+
} else {
166+
l.werr = delimWrite(buf, []byte(proto.ID))
167+
}
153168
if l.werr != nil {
154169
return 0
155170
}

multistream.go

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package multistream
55

66
import (
77
"bufio"
8-
"encoding/hex"
8+
"bytes"
99
"errors"
1010
"fmt"
1111
"io"
@@ -26,6 +26,9 @@ var ErrUnknownPrefix = errors.New("unknown protocol hash prefix")
2626
// the multistream muxers on both sides of a channel can work with each other.
2727
const ProtocolID = "/multistream/1.0.0"
2828

29+
// ProtocolID identifies the multistream protocol abbreviation support
30+
var ProtocolAbbrev = []byte{0xff, 0x11}
31+
2932
// Multistream-select version that protocol abbreviation is supported
3033
const AbbrevSupportedMSSVersion = 2
3134

@@ -186,24 +189,6 @@ func (msm *MultistreamMuxer[T]) Protocols() []T {
186189
// fails because of a ProtocolID mismatch.
187190
var ErrIncorrectVersion = errors.New("client connected with incorrect version")
188191

189-
func (msm *MultistreamMuxer[T]) decodeProtocol(s T) (T, error) {
190-
msm.handlerlock.RLock()
191-
defer msm.handlerlock.RUnlock()
192-
193-
bytes, err := hex.DecodeString(string(s))
194-
// TODO: decide whether to compare strings or use abbrevTree by looking at
195-
// multistream version instead.
196-
if err != nil {
197-
return s, nil
198-
}
199-
200-
proto, err := msm.abbrevTree.GetProtocolID(bytes)
201-
if err != nil {
202-
return "", err
203-
}
204-
return proto, nil
205-
}
206-
207192
func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] {
208193
msm.handlerlock.RLock()
209194
defer msm.handlerlock.RUnlock()
@@ -227,35 +212,49 @@ func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, hand
227212
}
228213
}()
229214

230-
// Send the multistream protocol ID
231-
// Ignore the error here. We want the handshake to finish, even if the
232-
// other side has closed this rwc for writing. They may have sent us a
233-
// message and closed. Future writers will get an error anyways.
234-
_ = delimWriteBuffered(rwc, []byte(ProtocolID))
235-
line, err := ReadNextToken[T](rwc)
215+
token, err := ReadNextTokenBytes(rwc)
236216
if err != nil {
237217
return "", nil, err
238218
}
239-
240-
if line != ProtocolID {
219+
supportAbbrev := false
220+
// Send the multistream protocol ID or the mulstream protocol abbreviation
221+
// Ignore the error here. We want the handshake to finish, even if the
222+
// other side has closed this rwc for writing. They may have sent us a
223+
// message and closed. Future writers will get an error anyways.
224+
if bytes.Equal(token, ProtocolAbbrev) {
225+
supportAbbrev = true
226+
_ = delimWriteBuffered(rwc, ProtocolAbbrev)
227+
} else if T(token) == ProtocolID {
228+
_ = delimWriteBuffered(rwc, []byte(ProtocolID))
229+
} else {
241230
rwc.Close()
242231
return "", nil, ErrIncorrectVersion
243232
}
244233

245234
loop:
246235
for {
247236
// Now read and respond to commands until they send a valid protocol id
248-
tok, err := ReadNextToken[T](rwc)
237+
var proto T
238+
239+
tok, err := ReadNextTokenBytes(rwc)
249240
if err != nil {
250241
return "", nil, err
251242
}
252243

253-
p, err := msm.decodeProtocol(tok)
254-
if err != nil {
255-
return "", nil, err
244+
if supportAbbrev {
245+
// decode the protocol abbreviation using the abbreviation tree
246+
msm.handlerlock.RLock()
247+
proto, err = msm.abbrevTree.GetProtocolID(tok)
248+
msm.handlerlock.RUnlock()
249+
250+
if err != nil {
251+
return "", nil, err
252+
}
253+
} else {
254+
proto = T(tok)
256255
}
257256

258-
h := msm.findHandler(p)
257+
h := msm.findHandler(proto)
259258
if h == nil {
260259
if err := delimWriteBuffered(rwc, []byte("na")); err != nil {
261260
return "", nil, err
@@ -266,10 +265,10 @@ loop:
266265
// Ignore the error here. We want the handshake to finish, even if the
267266
// other side has closed this rwc for writing. They may have sent us a
268267
// message and closed. Future writers will get an error anyways.
269-
_ = delimWriteBuffered(rwc, []byte(tok))
268+
_ = delimWriteBuffered(rwc, tok)
270269

271270
// hand off processing to the sub-protocol handler
272-
return p, h.Handle, nil
271+
return proto, h.Handle, nil
273272
}
274273

275274
}

0 commit comments

Comments
 (0)