Skip to content

Commit b724994

Browse files
committed
nonce should be tracked by conn
1 parent e2790a3 commit b724994

File tree

4 files changed

+43
-37
lines changed

4 files changed

+43
-37
lines changed

conn.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ type Conn struct {
3636

3737
closed int32
3838
onCloseErrorCB func(c *Conn)
39+
40+
nonceIdx uint64
41+
peerNonceIdx uint64
3942
}
4043

4144
func (c *Conn) Close() error {
@@ -518,3 +521,27 @@ func (conn *Conn) notifyOnCloseError() {
518521
}
519522
conn.onCloseErrorCB(conn)
520523
}
524+
525+
func (conn *Conn) SetNonce(i uint64) {
526+
conn.nonceIdx = i
527+
}
528+
529+
func (conn *Conn) Nonce() uint64 {
530+
return conn.nonceIdx
531+
}
532+
533+
func (conn *Conn) SetPeerNonce(i uint64) {
534+
conn.peerNonceIdx = i
535+
}
536+
537+
func (conn *Conn) PeerNonce() uint64 {
538+
return conn.peerNonceIdx
539+
}
540+
541+
func (conn *Conn) IncrNonce() {
542+
conn.nonceIdx++
543+
}
544+
545+
func (conn *Conn) IncrPeerNonce() {
546+
conn.peerNonceIdx++
547+
}

security.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ type Security interface {
2424

2525
type SecurityEncryption interface {
2626
// Encrypt writes the encrypted form of data to w.
27-
Encrypt(data []byte, more bool) ([]byte, error)
27+
Encrypt(conn *Conn, data []byte, more bool) ([]byte, error)
2828

2929
// Decrypt writes the decrypted form of data to w.
30-
Decrypt(data []byte) ([]byte, bool, error)
30+
Decrypt(conn *Conn, data []byte) ([]byte, bool, error)
3131
}
3232

3333
// SecurityType denotes types of ZMTP security mechanisms

security/curve/curve.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ type security struct {
5050
secretKey [keySize]byte // Derived shared secret key
5151
asServer bool // True if this is a server
5252
sharedKey [keySize]byte // Pre-computed shared key (for optimization)
53-
nonceIdx uint64
54-
peerNonceIdx uint64
5553
}
5654

5755
// SecurityForClient returns a CURVE security mechanism for a client.
@@ -123,8 +121,8 @@ func (sec *security) clientHandshake(conn *zmq4.Conn) error {
123121
return fmt.Errorf("security/curve: failed READY: %w", err)
124122
}
125123

126-
sec.nonceIdx = 3
127-
sec.peerNonceIdx = 1
124+
conn.SetNonce(3)
125+
conn.SetPeerNonce(1)
128126

129127
box.Precompute(&sec.sharedKey, &sec.secretKey, &sec.ephemeral.Private)
130128

@@ -168,26 +166,27 @@ func (sec *security) serverHandshake(conn *zmq4.Conn) error {
168166
return fmt.Errorf("security/curve: Server ready failed: %w", err)
169167
}
170168

171-
sec.nonceIdx = 2
172-
sec.peerNonceIdx = 2
169+
conn.SetNonce(2)
170+
conn.SetPeerNonce(2)
171+
173172
box.Precompute(&sec.sharedKey, &clientTransPubKey, &kp.Private)
174173
return nil
175174
}
176175

177176
// Encrypt writes the encrypted form of data to w.
178-
func (sec *security) Encrypt(data []byte, more bool) ([]byte, error) {
179-
defer func() { sec.nonceIdx++ }()
177+
func (sec *security) Encrypt(conn *zmq4.Conn, data []byte, more bool) ([]byte, error) {
178+
defer func() { conn.IncrNonce() }()
180179
out := make([]byte, 8+8+17+len(data))
181180
out[0] = uint8(7)
182181
copy(out[1:], "MESSAGE")
183182

184183
var nonce Nonce
185184
if sec.asServer {
186-
nonce.Short("CurveZMQMESSAGES", sec.nonceIdx) // From server
185+
nonce.Short("CurveZMQMESSAGES", conn.Nonce()) // From server
187186
} else {
188-
nonce.Short("CurveZMQMESSAGEC", sec.nonceIdx) // From client
187+
nonce.Short("CurveZMQMESSAGEC", conn.Nonce()) // From client
189188
}
190-
binary.BigEndian.AppendUint64(out[8:8], sec.nonceIdx)
189+
binary.BigEndian.AppendUint64(out[8:8], conn.Nonce())
191190
toSeal := make([]byte, 1+len(data))
192191
if more {
193192
toSeal[0] = 0x1
@@ -198,7 +197,7 @@ func (sec *security) Encrypt(data []byte, more bool) ([]byte, error) {
198197
}
199198

200199
// Decrypt writes the decrypted form of data to w.
201-
func (sec *security) Decrypt(body []byte) ([]byte, bool, error) {
200+
func (sec *security) Decrypt(conn *zmq4.Conn, body []byte) ([]byte, bool, error) {
202201
if len(body) < 33 {
203202
return nil, false, fmt.Errorf("security/curve: invalid message: too short")
204203
}
@@ -217,10 +216,10 @@ func (sec *security) Decrypt(body []byte) ([]byte, bool, error) {
217216
} else {
218217
nonce.Short("CurveZMQMESSAGES", shortNonce) // From server
219218
}
220-
if shortNonce != sec.peerNonceIdx+1 {
221-
return nil, false, fmt.Errorf("Peer used invalid nonce (expected %d, got %d)", sec.peerNonceIdx+1, shortNonce)
219+
if shortNonce != conn.PeerNonce()+1 {
220+
return nil, false, fmt.Errorf("Peer used invalid nonce (expected %d, got %d)", conn.PeerNonce()+1, shortNonce)
222221
}
223-
sec.peerNonceIdx++
222+
conn.IncrPeerNonce()
224223
copy(nonce[16:], body[8:])
225224
out := make([]byte, len(body)-32)
226225
out, ok := box.OpenAfterPrecomputation(out[0:0], body[16:], nonce.N(), &sec.sharedKey)

security_test.go

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
package zmq4
66

77
import (
8-
"bytes"
98
"context"
109
"fmt"
1110
"os"
@@ -22,25 +21,6 @@ func TestNullSecurity(t *testing.T) {
2221
if got, want := sec.Type(), NullSecurity; got != want {
2322
t.Fatalf("got=%v, want=%v", got, want)
2423
}
25-
26-
data := []byte("hello world")
27-
wenc := new(bytes.Buffer)
28-
if _, err := sec.Encrypt(wenc, data); err != nil {
29-
t.Fatalf("error encrypting data: %+v", err)
30-
}
31-
32-
if !bytes.Equal(wenc.Bytes(), data) {
33-
t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data)
34-
}
35-
36-
wdec := new(bytes.Buffer)
37-
if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil {
38-
t.Fatalf("error decrypting data: %+v", err)
39-
}
40-
41-
if !bytes.Equal(wdec.Bytes(), data) {
42-
t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data)
43-
}
4424
}
4525

4626
func TestNullHandshakeReqRep(t *testing.T) {

0 commit comments

Comments
 (0)