Skip to content

Commit f56b1c3

Browse files
committed
add support for sending error codes on stream reset
1 parent 8adb9a8 commit f56b1c3

File tree

3 files changed

+93
-9
lines changed

3 files changed

+93
-9
lines changed

const.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,27 @@ func (e *GoAwayError) Is(target error) bool {
5757
return false
5858
}
5959

60+
// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
61+
type StreamError struct {
62+
ErrorCode uint32
63+
Remote bool
64+
}
65+
66+
func (s *StreamError) Error() string {
67+
if s.Remote {
68+
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)
69+
}
70+
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)
71+
}
72+
73+
func (s *StreamError) Is(target error) bool {
74+
if target == ErrStreamReset {
75+
return true
76+
}
77+
e, ok := target.(*StreamError)
78+
return ok && *e == *s
79+
}
80+
6081
var (
6182
// ErrInvalidVersion means we received a frame with an
6283
// invalid version

session_test.go

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"testing"
1717
"time"
1818

19+
"github.com/stretchr/testify/assert"
1920
"github.com/stretchr/testify/require"
2021
)
2122

@@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) {
15711572
wc.Wait()
15721573
}
15731574

1575+
func TestStreamResetWithError(t *testing.T) {
1576+
client, server := testClientServer()
1577+
defer client.Close()
1578+
defer server.Close()
1579+
1580+
wc := new(sync.WaitGroup)
1581+
wc.Add(2)
1582+
go func() {
1583+
defer wc.Done()
1584+
stream, err := server.AcceptStream()
1585+
if err != nil {
1586+
t.Error(err)
1587+
}
1588+
1589+
se := &StreamError{}
1590+
_, err = io.ReadAll(stream)
1591+
if !errors.As(err, &se) {
1592+
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1593+
return
1594+
}
1595+
expected := &StreamError{Remote: true, ErrorCode: 42}
1596+
assert.Equal(t, se, expected)
1597+
}()
1598+
1599+
stream, err := client.OpenStream(context.Background())
1600+
if err != nil {
1601+
t.Error(err)
1602+
}
1603+
1604+
go func() {
1605+
defer wc.Done()
1606+
1607+
se := &StreamError{}
1608+
_, err := io.ReadAll(stream)
1609+
if !errors.As(err, &se) {
1610+
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1611+
return
1612+
}
1613+
expected := &StreamError{Remote: false, ErrorCode: 42}
1614+
assert.Equal(t, se, expected)
1615+
}()
1616+
1617+
time.Sleep(1 * time.Second)
1618+
err = stream.ResetWithError(42)
1619+
if err != nil {
1620+
t.Fatal(err)
1621+
}
1622+
wc.Wait()
1623+
}
1624+
15741625
func TestLotsOfWritesWithStreamDeadline(t *testing.T) {
15751626
config := testConf()
15761627
config.EnableKeepAlive = false
@@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) {
18091860
require.NoError(t, err)
18101861
str.SetDeadline(time.Now().Add(time.Second))
18111862
_, err = str.Read([]byte{0})
1812-
require.EqualError(t, err, "stream reset")
1863+
require.ErrorIs(t, err, ErrStreamReset)
18131864

18141865
// Now close one of the streams.
18151866
// This should then allow the client to open a new stream.

stream.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type Stream struct {
4242
state streamState
4343
writeState, readState halfStreamState
4444
stateLock sync.Mutex
45+
resetErr *StreamError
4546

4647
recvBuf segmentedBuffer
4748

@@ -89,6 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) {
8990
START:
9091
s.stateLock.Lock()
9192
state := s.readState
93+
resetErr := s.resetErr
9294
s.stateLock.Unlock()
9395

9496
switch state {
@@ -101,7 +103,7 @@ START:
101103
}
102104
// Closed, but we have data pending -> read.
103105
case halfReset:
104-
return 0, ErrStreamReset
106+
return 0, resetErr
105107
default:
106108
panic("unknown state")
107109
}
@@ -147,6 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) {
147149
START:
148150
s.stateLock.Lock()
149151
state := s.writeState
152+
resetErr := s.resetErr
150153
s.stateLock.Unlock()
151154

152155
switch state {
@@ -155,7 +158,7 @@ START:
155158
case halfClosed:
156159
return 0, ErrStreamClosed
157160
case halfReset:
158-
return 0, ErrStreamReset
161+
return 0, resetErr
159162
default:
160163
panic("unknown state")
161164
}
@@ -250,13 +253,17 @@ func (s *Stream) sendClose() error {
250253
}
251254

252255
// sendReset is used to send a RST
253-
func (s *Stream) sendReset() error {
254-
hdr := encode(typeWindowUpdate, flagRST, s.id, 0)
256+
func (s *Stream) sendReset(errCode uint32) error {
257+
hdr := encode(typeWindowUpdate, flagRST, s.id, errCode)
255258
return s.session.sendMsg(hdr, nil, nil)
256259
}
257260

258261
// Reset resets the stream (forcibly closes the stream)
259262
func (s *Stream) Reset() error {
263+
return s.ResetWithError(0)
264+
}
265+
266+
func (s *Stream) ResetWithError(errCode uint32) error {
260267
sendReset := false
261268
s.stateLock.Lock()
262269
switch s.state {
@@ -281,10 +288,11 @@ func (s *Stream) Reset() error {
281288
s.readState = halfReset
282289
}
283290
s.state = streamFinished
291+
s.resetErr = &StreamError{Remote: false, ErrorCode: errCode}
284292
s.notifyWaiting()
285293
s.stateLock.Unlock()
286294
if sendReset {
287-
_ = s.sendReset()
295+
_ = s.sendReset(errCode)
288296
}
289297
s.cleanup()
290298
return nil
@@ -382,7 +390,7 @@ func (s *Stream) cleanup() {
382390

383391
// processFlags is used to update the state of the stream
384392
// based on set flags, if any. Lock must be held
385-
func (s *Stream) processFlags(flags uint16) {
393+
func (s *Stream) processFlags(flags uint16, hdr header) {
386394
// Close the stream without holding the state lock
387395
var closeStream bool
388396
defer func() {
@@ -425,6 +433,10 @@ func (s *Stream) processFlags(flags uint16) {
425433
s.writeState = halfReset
426434
}
427435
s.state = streamFinished
436+
// Length in a window update frame with RST flag encodes an error code.
437+
if hdr.MsgType() == typeWindowUpdate && s.resetErr == nil {
438+
s.resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()}
439+
}
428440
s.stateLock.Unlock()
429441
closeStream = true
430442
s.notifyWaiting()
@@ -439,15 +451,15 @@ func (s *Stream) notifyWaiting() {
439451

440452
// incrSendWindow updates the size of our send window
441453
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
442-
s.processFlags(flags)
454+
s.processFlags(flags, hdr)
443455
// Increase window, unblock a sender
444456
atomic.AddUint32(&s.sendWindow, hdr.Length())
445457
asyncNotify(s.sendNotifyCh)
446458
}
447459

448460
// readData is used to handle a data frame
449461
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
450-
s.processFlags(flags)
462+
s.processFlags(flags, hdr)
451463

452464
// Check that our recv window is not exceeded
453465
length := hdr.Length()

0 commit comments

Comments
 (0)