Skip to content

Commit 21c4378

Browse files
committed
Make DNS responder thread safe
1 parent e8f6763 commit 21c4378

File tree

3 files changed

+65
-37
lines changed

3 files changed

+65
-37
lines changed

pkg/registrars/dns-registrar/tworeqresp/requester.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,7 @@ func splitIntoChunks(data []byte, mtu int) [][]byte {
7171
var chunks [][]byte
7272

7373
for i := 0; i < len(data); i += mtu {
74-
end := i + mtu
75-
76-
if end > len(data) {
77-
end = len(data)
78-
}
74+
end := min(i+mtu, len(data))
7975

8076
chunks = append(chunks, data[i:end])
8177
}

pkg/registrars/dns-registrar/tworeqresp/requester_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func TestSpliting(t *testing.T) {
5151
if err != nil {
5252
t.Fatalf("error creating requester: %v", err)
5353
}
54-
requester.RequestAndRecv(testCase.data)
54+
_, _ = requester.RequestAndRecv(testCase.data)
5555
if parent.calls != testCase.chunksExpected {
5656
t.Fatalf("calls: %v, expected: %v", parent.calls, testCase.chunksExpected)
5757
}

pkg/registrars/dns-registrar/tworeqresp/responder.go

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tworeqresp
33
import (
44
"bytes"
55
"fmt"
6+
"sync"
67
"time"
78

89
pb "github.com/refraction-networking/conjure/proto"
@@ -19,33 +20,44 @@ type oneresponder interface {
1920
type Responder struct {
2021
parent oneresponder
2122
parts map[[idLen]byte]*timedData
23+
mutex sync.Mutex
2224
}
2325

2426
func NewResponder(parent oneresponder) (*Responder, error) {
25-
return &Responder{
27+
r := &Responder{
2628
parent: parent,
2729
parts: make(map[[idLen]byte]*timedData),
28-
}, nil
29-
}
30-
31-
type timedData struct {
32-
data [][]byte
33-
expiry time.Time
30+
mutex: sync.Mutex{},
31+
}
32+
go r.gc()
33+
return r, nil
3434
}
3535

36-
func (r *Responder) RecvAndRespond(parentGetResponse func([]byte) ([]byte, error)) error {
36+
func (r *Responder) gc() {
3737
ticker := time.NewTicker(interval)
38-
getResponse := func(data []byte) ([]byte, error) {
39-
select {
40-
case <-ticker.C:
38+
39+
for range ticker.C {
40+
func() {
41+
r.mutex.Lock()
42+
defer r.mutex.Unlock()
4143
for key, data := range r.parts {
4244
if time.Now().After(data.expiry) {
4345
delete(r.parts, key)
4446
}
4547
}
46-
default:
47-
}
48+
}()
49+
}
50+
51+
}
4852

53+
type timedData struct {
54+
data [][]byte
55+
expiry time.Time
56+
mutex sync.Mutex
57+
}
58+
59+
func (r *Responder) RecvAndRespond(parentGetResponse func([]byte) ([]byte, error)) error {
60+
getResponse := func(data []byte) ([]byte, error) {
4961
partIn := &pb.DnsPartReq{}
5062
err := proto.Unmarshal(data, partIn)
5163
if err != nil {
@@ -58,29 +70,53 @@ func (r *Responder) RecvAndRespond(parentGetResponse func([]byte) ([]byte, error
5870

5971
partId := (*[idLen]byte)(partIn.GetId())
6072

61-
if _, ok := r.parts[*partId]; !ok {
73+
r.mutex.Lock()
74+
regData, ok := r.parts[*partId]
75+
76+
if !ok {
6277
r.parts[*partId] = &timedData{
6378
data: make([][]byte, partIn.GetTotalParts()),
6479
expiry: time.Now().Add(interval),
80+
mutex: sync.Mutex{},
6581
}
82+
regData = r.parts[*partId]
6683
}
84+
r.mutex.Unlock()
6785

68-
if int(partIn.GetTotalParts()) != len(r.parts[*partId].data) {
69-
return nil, fmt.Errorf("invalid total parts")
70-
}
86+
buf, waiting, err := func() ([]byte, bool, error) {
87+
regData.mutex.Lock()
88+
defer regData.mutex.Unlock()
89+
if int(partIn.GetTotalParts()) != len(regData.data) {
90+
return nil, false, fmt.Errorf("invalid total parts")
91+
}
7192

72-
if int(partIn.GetPartNum()) >= len(r.parts[*partId].data) {
73-
return nil, fmt.Errorf("part number out of bound")
74-
}
93+
if int(partIn.GetPartNum()) >= len(regData.data) {
94+
return nil, false, fmt.Errorf("part number out of bound")
95+
}
7596

76-
r.parts[*partId].data[partIn.GetPartNum()] = partIn.GetData()
97+
regData.data[partIn.GetPartNum()] = partIn.GetData()
7798

78-
waiting := false
79-
for _, part := range r.parts[*partId].data {
80-
if part == nil {
81-
waiting = true
82-
break
99+
waiting := false
100+
for _, part := range regData.data {
101+
if part == nil {
102+
waiting = true
103+
break
104+
}
105+
}
106+
if waiting {
107+
return nil, true, nil
108+
}
109+
110+
var buffer bytes.Buffer
111+
for _, part := range regData.data {
112+
buffer.Write(part)
83113
}
114+
115+
return buffer.Bytes(), false, nil
116+
}()
117+
118+
if err != nil {
119+
return nil, err
84120
}
85121

86122
if waiting {
@@ -93,11 +129,7 @@ func (r *Responder) RecvAndRespond(parentGetResponse func([]byte) ([]byte, error
93129
return respBytes, nil
94130
}
95131

96-
var buffer bytes.Buffer
97-
for _, part := range r.parts[*partId].data {
98-
buffer.Write(part)
99-
}
100-
res, err := parentGetResponse(buffer.Bytes())
132+
res, err := parentGetResponse(buf)
101133
if err != nil {
102134
return nil, fmt.Errorf("error from parent getResponse: %v", err)
103135
}

0 commit comments

Comments
 (0)