diff --git a/conn.go b/conn.go index 2da23bd..b7e8f14 100644 --- a/conn.go +++ b/conn.go @@ -5,7 +5,6 @@ package zmq4 import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -28,8 +27,9 @@ type Conn struct { Server bool Meta Metadata Peer struct { - Server bool - Meta Metadata + Server bool + Meta Metadata + NonceIdx uint64 } mu sync.RWMutex @@ -37,6 +37,9 @@ type Conn struct { closed int32 onCloseErrorCB func(c *Conn) + + NonceIdx uint64 + SharedKey *[32]byte } func (c *Conn) Close() error { @@ -61,6 +64,14 @@ func (c *Conn) Write(p []byte) (int, error) { return n, err } +func (c *Conn) LocalAddr() net.Addr { + return c.rw.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.rw.RemoteAddr() +} + // Open opens a ZMTP connection over rw with the given security, socket type and identity. // An optional onCloseErrorCB can be provided to inform the caller when this Conn is closed. // Open performs a complete ZMTP handshake. @@ -169,7 +180,7 @@ func (c *Conn) SendCmd(name string, body []byte) error { if err != nil { return err } - return c.send(true, buf, 0) + return c.send(true, buf, false) } // SendMsg sends a ZMTP message over the wire. @@ -183,11 +194,11 @@ func (c *Conn) SendMsg(msg Msg) error { nframes := len(msg.Frames) for i, frame := range msg.Frames { - var flag byte + var more bool if i < nframes-1 { - flag ^= hasMoreBitFlag + more = true } - err := c.send(false, frame, flag) + err := c.send(false, frame, more) if err != nil { return fmt.Errorf("zmq4: error sending frame %d/%d: %w", i+1, nframes, err) } @@ -286,8 +297,18 @@ func (c *Conn) sendMulti(msg Msg) error { nframes := len(msg.Frames) for i, frame := range msg.Frames { var flag byte + var more bool if i < nframes-1 { flag ^= hasMoreBitFlag + more = true + } + + if sec, ok := c.sec.(SecurityEncryption); ok { + encrypt, err := sec.Encrypt(c, frame, more) + if err != nil { + return err + } + frame = encrypt } size := len(frame) @@ -308,16 +329,7 @@ func (c *Conn) sendMulti(msg Msg) error { hdr[1] = uint8(size) } - switch c.sec.Type() { - case NullSecurity: - buffers = append(buffers, hdr[:hsz], frame) - default: - var secBuf bytes.Buffer - if _, err := c.sec.Encrypt(&secBuf, frame); err != nil { - return err - } - buffers = append(buffers, hdr[:hsz], secBuf.Bytes()) - } + buffers = append(buffers, hdr[:hsz], frame) } if _, err := buffers.WriteTo(c.rw); err != nil { @@ -328,7 +340,21 @@ func (c *Conn) sendMulti(msg Msg) error { return nil } -func (c *Conn) send(isCommand bool, body []byte, flag byte) error { +func (c *Conn) send(isCommand bool, body []byte, more bool) error { + var flag byte + + // commands should not be encrypted. + if sec, ok := c.sec.(SecurityEncryption); ok && !isCommand { + encrypt, err := sec.Encrypt(c, body, more) + if err != nil { + c.checkIO(err) + return err + } + body = encrypt + } else if more { + flag ^= hasMoreBitFlag + } + // Long flag size := len(body) isLong := size > 255 @@ -358,7 +384,7 @@ func (c *Conn) send(isCommand bool, body []byte, flag byte) error { return err } - if _, err := c.sec.Encrypt(c.rw, body); err != nil { + if _, err := c.rw.Write(body); err != nil { c.checkIO(err) return err } @@ -427,11 +453,16 @@ func (c *Conn) read() Msg { continue } - buf := new(bytes.Buffer) - if _, msg.err = c.sec.Decrypt(buf, body); msg.err != nil { - return msg + if sec, ok := c.sec.(SecurityEncryption); ok && !isCmd { + decrypt, more, err := sec.Decrypt(c, body) + if err != nil { + msg.err = err + return msg + } + body = decrypt + hasMore = more } - msg.Frames = append(msg.Frames, buf.Bytes()) + msg.Frames = append(msg.Frames, body) } if isCmd { msg.Type = CmdMsg diff --git a/go.mod b/go.mod index a015ad2..a259886 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,15 @@ module github.com/go-zeromq/zmq4 -go 1.21 +go 1.23.0 + +toolchain go1.24.2 require ( github.com/go-zeromq/goczmq/v4 v4.2.2 go.uber.org/goleak v1.3.0 - golang.org/x/sync v0.7.0 - golang.org/x/text v0.15.0 + golang.org/x/crypto v0.38.0 + golang.org/x/sync v0.14.0 + golang.org/x/text v0.25.0 ) + +require golang.org/x/sys v0.33.0 // indirect diff --git a/go.sum b/go.sum index d204e12..ef05c1d 100644 --- a/go.sum +++ b/go.sum @@ -8,9 +8,13 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/security.go b/security.go index 959df89..1eb28da 100644 --- a/security.go +++ b/security.go @@ -6,7 +6,6 @@ package zmq4 import ( "fmt" - "io" ) // Security is an interface for ZMTP security mechanisms @@ -21,12 +20,14 @@ type Security interface { // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ Handshake(conn *Conn, server bool) error +} +type SecurityEncryption interface { // Encrypt writes the encrypted form of data to w. - Encrypt(w io.Writer, data []byte) (int, error) + Encrypt(conn *Conn, data []byte, more bool) ([]byte, error) // Decrypt writes the decrypted form of data to w. - Decrypt(w io.Writer, data []byte) (int, error) + Decrypt(conn *Conn, data []byte) ([]byte, bool, error) } // SecurityType denotes types of ZMTP security mechanisms @@ -90,16 +91,6 @@ func (nullSecurity) Handshake(conn *Conn, server bool) error { return nil } -// Encrypt writes the encrypted form of data to w. -func (nullSecurity) Encrypt(w io.Writer, data []byte) (int, error) { - return w.Write(data) -} - -// Decrypt writes the decrypted form of data to w. -func (nullSecurity) Decrypt(w io.Writer, data []byte) (int, error) { - return w.Write(data) -} - var ( _ Security = (*nullSecurity)(nil) ) diff --git a/security/curve/curve.go b/security/curve/curve.go new file mode 100644 index 0000000..5d34a97 --- /dev/null +++ b/security/curve/curve.go @@ -0,0 +1,481 @@ +// Copyright 2024 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package curve provides the ZeroMQ CURVE security mechanism as specified by: +// https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ +package curve + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "unsafe" + + "github.com/go-zeromq/zmq4" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/secretbox" +) + +const ( + keySize = 32 // Size of public and private keyFunc + nonceSize = 24 // Size of nonce +) + +// KeyPair represents a CurveZMQ key pair. +type KeyPair struct { + Public [keySize]byte + Private [keySize]byte +} + +// NewKeyPair generates a new random keypair for curve security. +func NewKeyPair() (*KeyPair, error) { + var kp KeyPair + pub, priv, err := box.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + copy(kp.Public[:], pub[:]) + copy(kp.Private[:], priv[:]) + return &kp, nil +} + +// security implements the CURVE security mechanism. +type security struct { + serverPubKey [keySize]byte // Long-term server public key + clientKeyPair *KeyPair // client KeyPair + keyFunc func(clientKey *[32]byte) (*KeyPair, error) // Func for retrieving server's KeyPair + asServer bool // True if this is a server +} + +// SecurityForClient returns a CURVE security mechanism for a client. +// The client must know the server's public key. +func SecurityForClient(serverKey [keySize]byte, clientKeys *KeyPair) zmq4.Security { + sec := &security{ + serverPubKey: serverKey, + clientKeyPair: clientKeys, + asServer: false, + } + return sec +} + +// SecurityForServer returns a CURVE security mechanism for a server. +// The server must have its own key pair. +func SecurityForServer(serverKeys *KeyPair) zmq4.Security { + sec := &security{ + keyFunc: func(_ *[32]byte) (*KeyPair, error) { + return serverKeys, nil + }, + asServer: true, + } + return sec +} + +// SecurityForServerFunc returns a CURVE security mechanism for a server. +// The server must have its own key pair. +func SecurityForServerFunc(keyFunc func(*[32]byte) (*KeyPair, error)) zmq4.Security { + sec := &security{ + keyFunc: keyFunc, + asServer: true, + } + return sec +} + +// Type returns the security mechanism type. +func (security) Type() zmq4.SecurityType { + return zmq4.CurveSecurity +} + +// Handshake implements the ZMTP security handshake according to +// the CURVE security mechanism. +// see: https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ +func (sec *security) Handshake(conn *zmq4.Conn, server bool) error { + if server != sec.asServer { + return fmt.Errorf("security/curve: invalid server flag, got=%v, want=%v", server, sec.asServer) + } + + // Create new ephemeral key pair for this connection + var err error + ephemeral, err := NewKeyPair() + if err != nil { + return fmt.Errorf("security/curve: could not generate session keypair: %w", err) + } + + if server { + return sec.serverHandshake(conn) + } + return sec.clientHandshake(conn, ephemeral) +} + +func (sec *security) clientHandshake(conn *zmq4.Conn, ephemeral *KeyPair) error { + var nonce Nonce + err := sec.doHello(conn, &nonce, ephemeral) + if err != nil { + return fmt.Errorf("security/curve: could not send HELLO to server: %w", err) + } + + servCookie, secretKey, err := sec.doWelcome(&nonce, conn, ephemeral) + if err != nil { + return fmt.Errorf("security/curve: failed WELCOME: %w", err) + } + + err = sec.doInitiate(conn, servCookie, &nonce, sec.clientKeyPair, secretKey, ephemeral) + if err != nil { + return fmt.Errorf("security/curve: failed INITIATE: %w", err) + } + + servMeta, err := sec.doReady(conn, &nonce, secretKey, ephemeral) + if err != nil { + return fmt.Errorf("security/curve: failed READY: %w", err) + } + + conn.NonceIdx = 3 + conn.Peer.NonceIdx = 1 + var sharedKey [32]byte + box.Precompute(&sharedKey, secretKey, &ephemeral.Private) + conn.SharedKey = &sharedKey + + // Unmarshal the server metadata + err = conn.Peer.Meta.UnmarshalZMTP(servMeta) + if err != nil { + return fmt.Errorf("security/curve: could not unmarshal server metadata: %w", err) + } + + return nil +} + +func (sec *security) serverHandshake(conn *zmq4.Conn) error { + var nonce Nonce + var cookieKey [32]byte + + clientTransPubKey, err := sec.doServerHello(&nonce, conn) + if err != nil { + return fmt.Errorf("security/curve: Client hello failed: %w", err) + } + + kp, err := NewKeyPair() + if err != nil { + panic(fmt.Sprintf("Failed creating cookie key: %s", err.Error())) + } + err = sec.doServerWelcome(&nonce, conn, &clientTransPubKey, &cookieKey, kp) + if err != nil { + return fmt.Errorf("security/curve: Failed sending welcome: %w", err) + } + + clientMeta, err := sec.doServerInitiate(&nonce, conn, &cookieKey, &clientTransPubKey, &kp.Private) + if err != nil { + return fmt.Errorf("security/curve: Client initiate failed: %w", err) + } + err = conn.Peer.Meta.UnmarshalZMTP(clientMeta) + if err != nil { + return fmt.Errorf("security/curve: Could not unmarshal client metadata: %w", err) + } + + err = sec.doServerReady(conn, &clientTransPubKey, &kp.Private) + if err != nil { + return fmt.Errorf("security/curve: Server ready failed: %w", err) + } + + conn.NonceIdx = 2 + conn.Peer.NonceIdx = 2 + var sharedKey [32]byte + box.Precompute(&sharedKey, &clientTransPubKey, &kp.Private) + conn.SharedKey = &sharedKey + return nil +} + +// Encrypt writes the encrypted form of data to w. +func (sec *security) Encrypt(conn *zmq4.Conn, data []byte, more bool) ([]byte, error) { + defer func() { conn.NonceIdx++ }() + out := make([]byte, 8+8+17+len(data)) + out[0] = uint8(7) + copy(out[1:], "MESSAGE") + + var nonce Nonce + if sec.asServer { + nonce.Short("CurveZMQMESSAGES", conn.NonceIdx) // From server + } else { + nonce.Short("CurveZMQMESSAGEC", conn.NonceIdx) // From client + } + binary.BigEndian.AppendUint64(out[8:8], conn.NonceIdx) + toSeal := make([]byte, 1+len(data)) + if more { + toSeal[0] = 0x1 + } + copy(toSeal[1:], data) + box.SealAfterPrecomputation(out[16:16], toSeal, nonce.N(), conn.SharedKey) + return out, nil +} + +// Decrypt writes the decrypted form of data to w. +func (sec *security) Decrypt(conn *zmq4.Conn, body []byte) ([]byte, bool, error) { + if len(body) < 33 { + return nil, false, fmt.Errorf("security/curve: invalid message: too short") + } + if body[0] != 7 { + return nil, false, fmt.Errorf("security/curve: expected command name to have 7 bytes, got %d", body[0]) + } + nameStr := unsafe.String(&body[1], 7) + if nameStr != "MESSAGE" { + return nil, false, fmt.Errorf("security/curve: expected MESSAGE command, got %s", nameStr) + } + + shortNonce := binary.BigEndian.Uint64(body[8:]) + var nonce Nonce + if sec.asServer { + nonce.Short("CurveZMQMESSAGEC", shortNonce) // From client + } else { + nonce.Short("CurveZMQMESSAGES", shortNonce) // From server + } + if shortNonce != conn.Peer.NonceIdx+1 { + return nil, false, fmt.Errorf("Peer used invalid nonce (expected %d, got %d)", conn.Peer.NonceIdx+1, shortNonce) + } + conn.Peer.NonceIdx++ + copy(nonce[16:], body[8:]) + out := make([]byte, len(body)-32) + out, ok := box.OpenAfterPrecomputation(out[0:0], body[16:], nonce.N(), conn.SharedKey) + if !ok { + return nil, false, fmt.Errorf("Failed opening message box") + } + more := (out[0] & 0x1) == 1 + out = out[1:] // remove "more" flag + + return out, more, nil +} + +func (sec *security) doHello(conn *zmq4.Conn, nonce *Nonce, ephemeral *KeyPair) error { + body := make([]byte, 194) + body[0] = 1 // version + copy(body[74:106], ephemeral.Public[:]) + body[113] = 1 + nonce.Short("CurveZMQHELLO---", 1) + var sigBox [64]byte + box.Seal(body[114:114], sigBox[:], nonce.N(), &sec.serverPubKey, &ephemeral.Private) + return conn.SendCmd(zmq4.CmdHello, body) +} + +func (sec *security) doWelcome(nonce *Nonce, conn *zmq4.Conn, ephemeral *KeyPair) ([]byte, *[32]byte, error) { + cmd, err := conn.RecvCmd() + if err != nil { + return nil, nil, fmt.Errorf("security/curve: could not receive WELCOME from server: %w", err) + } + if cmd.Name != zmq4.CmdWelcome { + return nil, nil, fmt.Errorf("security/curve: expected WELCOME command, got %s", cmd.Name) + } + if len(cmd.Body) != 160 { + return nil, nil, fmt.Errorf("security/curve: expected WELCOME body to be 160 bytes long") + } + + nonce.FromLong("WELCOME-", cmd.Body[:16]) + welcomeBox := make([]byte, 128) + _, ok := box.Open(welcomeBox[0:0], cmd.Body[16:], nonce.N(), &sec.serverPubKey, &ephemeral.Private) + if !ok { + return nil, nil, fmt.Errorf("Failed opening welcome box") + } + + var secretKey [32]byte + copy(secretKey[:], welcomeBox[:32]) + return welcomeBox[32:], &secretKey, nil +} + +func (sec *security) doInitiate(conn *zmq4.Conn, servCookie []byte, nonce *Nonce, keys *KeyPair, secretKey *[32]byte, ephemeral *KeyPair) error { + meta, err := conn.Meta.MarshalZMTP() + if err != nil { + return fmt.Errorf("security/curve: could not marshal metadata: %w", err) + } + initiateBody := make([]byte, 96+8+32+96+len(meta)+16) + copy(initiateBody[:96], servCookie) + initiateBody[103] = 2 + + // initiate::vouch + nonce.Long("VOUCH---") + vouch := make([]byte, 64) + copy(vouch, ephemeral.Public[:]) + copy(vouch[32:], sec.serverPubKey[:]) + vouchBox := make([]byte, 80) + box.Seal(vouchBox[0:0], vouch, nonce.N(), secretKey, &keys.Private) + + initBox := make([]byte, 128+len(meta)) + copy(initBox, keys.Public[:]) + copy(initBox[32:48], nonce[8:]) + copy(initBox[48:128], vouchBox) + copy(initBox[128:], meta) + nonce.Short("CurveZMQINITIATE", 2) + box.Seal(initiateBody[104:104], initBox, nonce.N(), secretKey, &ephemeral.Private) + return conn.SendCmd(zmq4.CmdInitiate, initiateBody) +} + +func (sec *security) doReady(conn *zmq4.Conn, nonce *Nonce, secretKey *[32]byte, ephemeral *KeyPair) ([]byte, error) { + cmd, err := conn.RecvCmd() + if err != nil { + return nil, fmt.Errorf("security/curve: could not receive READY from server: %w", err) + } + if cmd.Name != zmq4.CmdReady { + return nil, fmt.Errorf("security/curve: expected READY command, got %s", cmd.Name) + } + if len(cmd.Body) < 24 { + return nil, fmt.Errorf("security/curve: expected READY body to be at least 24 bytes long") + } + + servNonce := binary.BigEndian.Uint64(cmd.Body[:8]) + if servNonce != 1 { + return nil, fmt.Errorf("security/curve: expected server nonce to be 1, got %d", servNonce) + } + nonce.Short("CurveZMQREADY---", 1) + servMeta := make([]byte, len(cmd.Body)-24) + if _, ok := box.Open(servMeta[0:0], cmd.Body[8:], nonce.N(), secretKey, &ephemeral.Private); !ok { + return nil, fmt.Errorf("security/curve: failed opening metadata") + } + return servMeta, nil +} + +func (sec *security) doServerHello(nonce *Nonce, conn *zmq4.Conn) (clientTransPubKey [32]byte, err error) { + cmd, err := conn.RecvCmd() + if err != nil { + return clientTransPubKey, err + } + if cmd.Name != "HELLO" { + err = fmt.Errorf("security/curve: invalid handshake: expected hello, got %s", cmd.Name) + return + } + if len(cmd.Body) != 194 { + err = fmt.Errorf("security/curve: invalid hello: expected length to be 194 bytes, got %d", len(cmd.Body)) + return + } + if cmd.Body[0] != 1 || cmd.Body[1] != 0 { + err = fmt.Errorf("security/curve: Expected CURVEZMQ version 1.0, got %d.%d", cmd.Body[0], cmd.Body[1]) + return + } + + copy(clientTransPubKey[:], cmd.Body[74:106]) + cliNonceIdx := binary.BigEndian.Uint64(cmd.Body[106:114]) + nonce.Short("CurveZMQHELLO---", cliNonceIdx) + if cliNonceIdx != 1 { + err = fmt.Errorf("security/curve: Expected client nonce to be 1, got %d", cliNonceIdx) + return + } + var out [64]byte + + keys, err := sec.keyFunc(&clientTransPubKey) + if err != nil { + return clientTransPubKey, fmt.Errorf("security/curve: hello could not retrieve keypair: %w", err) + } + + _, ok := box.Open(out[0:0], cmd.Body[114:], nonce.N(), &clientTransPubKey, &keys.Private) + if !ok { + err = fmt.Errorf("security/curve: Invalid signature in hello command") + return + } + + for idx, byte := range out { + if byte != 0 { + err = fmt.Errorf("security/curve: Expected signature to contain only 0's, byte %d has value %x", idx, byte) + return + } + } + return +} + +func (sec *security) doServerWelcome(nonce *Nonce, conn *zmq4.Conn, clientTransPubKey, cookieKey *[32]byte, kp *KeyPair) error { + welcomeBody := make([]byte, 160) + var cookie [64]byte + copy(cookie[:], clientTransPubKey[:]) + copy(cookie[32:], kp.Private[:]) + PopulateSecKey(cookieKey) + + nonce.Long("COOKIE--") + cookieData := make([]byte, 96) + secretbox.Seal(cookieData[16:16], cookie[:], nonce.N(), cookieKey) + copy(cookieData[:16], nonce[8:]) + + welcomeBox := make([]byte, 128) + copy(welcomeBox, kp.Public[:]) + copy(welcomeBox[32:], cookieData) + nonce.Long("WELCOME-") + copy(welcomeBody, nonce[8:]) + keys, err := sec.keyFunc(clientTransPubKey) + if err != nil { + return fmt.Errorf("security/curve: welcome could not retrieve keypair: %w", err) + } + box.Seal(welcomeBody[16:16], welcomeBox, nonce.N(), clientTransPubKey, &keys.Private) + return conn.SendCmd(zmq4.CmdWelcome, welcomeBody) +} + +func PopulateSecKey(sec *[32]byte) { + _, err := io.ReadFull(rand.Reader, sec[:]) + if err != nil { + panic(err) + } +} + +func (sec *security) doServerInitiate(nonce *Nonce, conn *zmq4.Conn, cookieKey, clientTransPubKey, serverTransSecKey *[32]byte) ([]byte, error) { + cmd, err := conn.RecvCmd() + if err != nil { + return nil, fmt.Errorf("security/curve: could not receive INITIATE from server: %w", err) + } + if cmd.Name != "INITIATE" { + return nil, fmt.Errorf("security/curve: invalid handshake: expected initiate, got %s", cmd.Name) + } + if len(cmd.Body) < 248 { + return nil, fmt.Errorf("security/curve: invalid initiate: expected length to be at least 248 bytes, got %d", len(cmd.Body)) + } + + nonce.FromLong("COOKIE--", cmd.Body[:16]) + clientCookieBox := cmd.Body[16:96] + clientCookieData := make([]byte, 0, 64) + clientCookieData, ok := secretbox.Open(clientCookieData, clientCookieBox, nonce.N(), cookieKey) + if !ok { + return nil, fmt.Errorf("Client sent invalid cookie") + } + + var serverTransPubKey [32]byte + copy(serverTransSecKey[:], clientCookieData[32:]) + curve25519.ScalarBaseMult(&serverTransPubKey, serverTransSecKey) + + // second point to check client short nonce + cliNonceIdx := binary.BigEndian.Uint64(cmd.Body[96:104]) + if cliNonceIdx != 2 { + return nil, fmt.Errorf("Expected client nonce to be 2, got %d", cliNonceIdx) + } + nonce.Short("CurveZMQINITIATE", cliNonceIdx) + initBox := make([]byte, 0, len(cmd.Body)-120) + initBox, ok = box.Open(initBox, cmd.Body[104:], nonce.N(), clientTransPubKey, serverTransSecKey) + if !ok { + return nil, fmt.Errorf("Failed opening initiate box") + } + + var clientPermPublicKey [32]byte + copy(clientPermPublicKey[:], initBox[:32]) + vouch := initBox[32:128] + clientMeta := initBox[128:] + nonce.FromLong("VOUCH---", vouch[:16]) + vouchData := make([]byte, 0, 64) + vouchData, ok = box.Open(vouchData, vouch[16:], nonce.N(), &clientPermPublicKey, serverTransSecKey) + if !ok { + return nil, fmt.Errorf("Failed opening vouch box") + } + return clientMeta, nil +} + +func (sec *security) doServerReady(conn *zmq4.Conn, clientTransPubKey *[32]byte, + serverTransSecKey *[32]byte) error { + var nonce Nonce + nonce.Short("CurveZMQREADY---", 1) + + meta, err := conn.Meta.MarshalZMTP() + if err != nil { + return fmt.Errorf("security/curve: could not marshal metadata: %w", err) + } + + readyBody := make([]byte, len(meta)+16+8) + binary.BigEndian.PutUint64(readyBody[0:8], 1) + box.Seal(readyBody[8:8], meta, nonce.N(), clientTransPubKey, serverTransSecKey) + + return conn.SendCmd(zmq4.CmdReady, readyBody) +} + +var ( + _ zmq4.SecurityEncryption = (*security)(nil) +) diff --git a/security/curve/nonce.go b/security/curve/nonce.go new file mode 100644 index 0000000..05e1ec6 --- /dev/null +++ b/security/curve/nonce.go @@ -0,0 +1,38 @@ +package curve + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "unsafe" +) + +type Nonce [24]byte + +func (n *Nonce) Short(prefix string, nonce uint64) { + prefixBytes := unsafe.Slice(unsafe.StringData(prefix), 16) + asSlice := unsafe.Slice((*byte)(unsafe.Pointer(n)), 24) + copy(asSlice[:16], prefixBytes) + //binary.BigEndian.AppendUint64(asSlice[16:16], nonce) + binary.BigEndian.PutUint64(asSlice[16:], nonce) +} + +func (n *Nonce) Long(prefix string) { + prefixBytes := unsafe.Slice(unsafe.StringData(prefix), 8) + asSlice := unsafe.Slice((*byte)(unsafe.Pointer(n)), 24) + copy(asSlice[:8], prefixBytes) + if _, err := rand.Reader.Read(asSlice[8:]); err != nil { + panic(fmt.Errorf("Failed creating long nonce: %w", err)) + } +} + +func (n *Nonce) FromLong(prefix string, long []byte) { + prefixBytes := unsafe.Slice(unsafe.StringData(prefix), 8) + asSlice := unsafe.Slice((*byte)(unsafe.Pointer(n)), 24) + copy(asSlice[:8], prefixBytes) + copy(asSlice[8:], long) +} + +func (n *Nonce) N() *[24]byte { + return (*[24]byte)(n) +} diff --git a/security/curve/z85.go b/security/curve/z85.go new file mode 100644 index 0000000..145860e --- /dev/null +++ b/security/curve/z85.go @@ -0,0 +1,124 @@ +// Copyright 2023 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package curve + +import ( + "errors" + "fmt" + "math/big" +) + +// Z85Encoder provides the Z85 encoding for binary data. +// Z85 is a base-85 encoding designed for compactness and readability, used in ZeroMQ. +// The spec of Z85 is here: http://rfc.zeromq.org/spec:32/Z85/ +// +// Z85 only encodes data of a length divisible by 4. The encoded output +// length will be 5/4 of the input length. + +// Z85 encoding alphabet +var z85Encoder = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-:+=^!/*?&<>()[]{}@%$#") + +// Lookup table for Z85 decoding +var z85Decoder [256]byte + +func init() { + // Initialize the decoder lookup table + for i := range z85Decoder { + z85Decoder[i] = 0xff // Invalid character marker + } + for i, c := range z85Encoder { + z85Decoder[c] = byte(i) + } +} + +// Encode encodes src into Z85 encoding. +// The input length must be divisible by 4. +func Encode(src []byte) ([]byte, error) { + // Check that we have valid input + if len(src)%4 != 0 { + return nil, errors.New("z85: input length must be a multiple of 4 bytes") + } + + // Each 4 bytes of input becomes 5 bytes of output + encodedLen := len(src) * 5 / 4 + dst := make([]byte, encodedLen) + + // Process input in 4-byte chunks + for i, j := 0, 0; i < len(src); i += 4 { + // Convert 4 bytes to a 32-bit integer + value := uint32(src[i])<<24 | uint32(src[i+1])<<16 | uint32(src[i+2])<<8 | uint32(src[i+3]) + + // Encode the integer as 5 characters + for k := 4; k >= 0; k-- { + dst[j+k] = z85Encoder[value%85] + value /= 85 + } + j += 5 + } + + return dst, nil +} + +// Decode decodes Z85-encoded data. +// The input length must be divisible by 5. +func Decode(src []byte) ([]byte, error) { + // Check that we have valid input + if len(src)%5 != 0 { + return nil, errors.New("z85: encoded length must be a multiple of 5 bytes") + } + + // Each 5 bytes of input becomes 4 bytes of output + decodedLen := len(src) * 4 / 5 + dst := make([]byte, decodedLen) + + // Process input in 5-byte chunks + for i, j := 0, 0; i < len(src); i += 5 { + // Accumulate value in base 85 + value := new(big.Int) + base := big.NewInt(85) + + for k := 0; k < 5; k++ { + // Check for invalid characters + if src[i+k] >= 128 || z85Decoder[src[i+k]] == 0xff { + return nil, fmt.Errorf("z85: invalid character '%c' at position %d", src[i+k], i+k) + } + + digit := big.NewInt(int64(z85Decoder[src[i+k]])) + value.Mul(value, base) + value.Add(value, digit) + } + + // Convert 32-bit integer to 4 bytes + valueBytes := value.Bytes() + // Ensure we have 4 bytes (pad with leading zeros if needed) + padLen := 4 - len(valueBytes) + if padLen > 0 { + for k := 0; k < padLen; k++ { + dst[j+k] = 0 + } + copy(dst[j+padLen:j+4], valueBytes) + } else { + copy(dst[j:j+4], valueBytes[len(valueBytes)-4:]) + } + + j += 4 + } + + return dst, nil +} + +// EncodeString encodes a byte slice to a Z85 string. +func EncodeString(src []byte) (string, error) { + encoded, err := Encode(src) + if err != nil { + return "", err + } + return string(encoded), nil +} + +// DecodeString decodes a Z85 string to a byte slice. +func DecodeString(src string) ([]byte, error) { + return Decode([]byte(src)) +} diff --git a/security/curve/z85_test.go b/security/curve/z85_test.go new file mode 100644 index 0000000..df0b9a6 --- /dev/null +++ b/security/curve/z85_test.go @@ -0,0 +1,156 @@ +// Copyright 2023 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package curve + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func TestZ85EncodeDecode(t *testing.T) { + + decodeString, _ := hex.DecodeString("9493c171319a11c5469db7e81bae204768efa826b9a2f144ff4a581cdb3eed4b") + testCases := []struct { + name string + input []byte + encoded string + }{ + { + name: "HelloWorld", + input: []byte{0x86, 0x4F, 0xD2, 0x6F, 0xB5, 0x59, 0xF7, 0x5B}, + encoded: "HelloWorld", + }, + { + name: "EmptyString", + input: []byte{0, 0, 0, 0}, + encoded: "00000", + }, + { + name: "Binary", + input: decodeString, + encoded: "L-@}6f}5Y[mXd3L8)gqNxZ.(DXUwQZ%4lxk*DL0$", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test encoding + enc, err := Encode(tc.input) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + + if string(enc) != tc.encoded { + t.Errorf("Encode result mismatch: got %q, want %q", string(enc), tc.encoded) + } + + // Test decoding + dec, err := Decode([]byte(tc.encoded)) + if err != nil { + t.Fatalf("Decode failed: %v", err) + } + + if !bytes.Equal(dec, tc.input) { + t.Errorf("Decode result mismatch: got %v, want %v", dec, tc.input) + } + + // Test EncodeString and DecodeString convenience functions + encStr, err := EncodeString(tc.input) + if err != nil { + t.Fatalf("EncodeString failed: %v", err) + } + + if encStr != tc.encoded { + t.Errorf("EncodeString result mismatch: got %q, want %q", encStr, tc.encoded) + } + + decBytes, err := DecodeString(tc.encoded) + if err != nil { + t.Fatalf("DecodeString failed: %v", err) + } + + if !bytes.Equal(decBytes, tc.input) { + t.Errorf("DecodeString result mismatch: got %v, want %v", decBytes, tc.input) + } + }) + } +} + +func TestZ85InvalidInput(t *testing.T) { + // Test invalid input lengths for encoding + _, err := Encode([]byte{1, 2, 3}) // Not divisible by 4 + if err == nil { + t.Error("Expected error for input length not divisible by 4, got nil") + } + + // Test invalid input lengths for decoding + _, err = Decode([]byte{1, 2, 3, 4}) // Not divisible by 5 + if err == nil { + t.Error("Expected error for input length not divisible by 5, got nil") + } + + // Test invalid characters for decoding + _, err = Decode([]byte("Hello~World")) // ~ is not in the Z85 alphabet + if err == nil { + t.Error("Expected error for invalid character, got nil") + } +} + +// Test the examples from the Z85 spec: http://rfc.zeromq.org/spec:32/Z85/ +func TestZ85SpecExamples(t *testing.T) { + // Example 1: a 32-byte CURVE key encoded with Z85 + key := []byte{ + 0x8E, 0x0B, 0xDD, 0x69, 0x76, 0x28, 0xB9, 0x1D, + 0x8F, 0x24, 0x55, 0x87, 0xEE, 0x95, 0xC5, 0xB0, + 0x4D, 0x48, 0x96, 0x3F, 0x79, 0x25, 0x98, 0x77, + 0xB4, 0x9C, 0xD9, 0x06, 0x3A, 0xEA, 0xD3, 0xB7, + } + + expected := "JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6" + + encoded, err := EncodeString(key) + if err != nil { + t.Fatalf("EncodeString failed: %v", err) + } + + if encoded != expected { + t.Errorf("Z85 spec example encoding mismatch: got %q, want %q", encoded, expected) + } + + decoded, err := DecodeString(expected) + if err != nil { + t.Fatalf("DecodeString failed: %v", err) + } + + if !bytes.Equal(decoded, key) { + t.Errorf("Z85 spec example decoding mismatch: got %v, want %v", decoded, key) + } +} + +func BenchmarkZ85Encode(b *testing.B) { + data := bytes.Repeat([]byte{1, 2, 3, 4}, 256) // 1KB + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := Encode(data) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkZ85Decode(b *testing.B) { + data := bytes.Repeat([]byte{1, 2, 3, 4}, 256) // 1KB + encoded, _ := Encode(data) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := Decode(encoded) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/security/null/null.go b/security/null/null.go index 5d1cecc..5a4b826 100644 --- a/security/null/null.go +++ b/security/null/null.go @@ -61,7 +61,7 @@ func (security) Handshake(conn *zmq4.Conn, server bool) error { } // Encrypt writes the encrypted form of data to w. -func (security) Encrypt(w io.Writer, data []byte) (int, error) { +func (security) Encrypt(w io.Writer, data []byte, more bool) (int, error) { return w.Write(data) } diff --git a/security/null/null_test.go b/security/null/null_test.go index 4115a43..bcbe49b 100644 --- a/security/null/null_test.go +++ b/security/null/null_test.go @@ -5,7 +5,6 @@ package null_test import ( - "bytes" "context" "fmt" "os" @@ -24,25 +23,6 @@ func TestSecurity(t *testing.T) { if got, want := sec.Type(), zmq4.NullSecurity; got != want { t.Fatalf("got=%v, want=%v", got, want) } - - data := []byte("hello world") - wenc := new(bytes.Buffer) - if _, err := sec.Encrypt(wenc, data); err != nil { - t.Fatalf("error encrypting data: %+v", err) - } - - if !bytes.Equal(wenc.Bytes(), data) { - t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data) - } - - wdec := new(bytes.Buffer) - if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil { - t.Fatalf("error decrypting data: %+v", err) - } - - if !bytes.Equal(wdec.Bytes(), data) { - t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data) - } } func TestHandshakeReqRep(t *testing.T) { diff --git a/security/plain/plain.go b/security/plain/plain.go index 2d9945b..6fb5e01 100644 --- a/security/plain/plain.go +++ b/security/plain/plain.go @@ -134,7 +134,7 @@ func (sec *security) Handshake(conn *zmq4.Conn, server bool) error { } // Encrypt writes the encrypted form of data to w. -func (security) Encrypt(w io.Writer, data []byte) (int, error) { +func (security) Encrypt(w io.Writer, data []byte, more bool) (int, error) { return w.Write(data) } diff --git a/security/plain/plain_test.go b/security/plain/plain_test.go index 15004ae..267e0b7 100644 --- a/security/plain/plain_test.go +++ b/security/plain/plain_test.go @@ -5,7 +5,6 @@ package plain_test import ( - "bytes" "context" "crypto/rand" "fmt" @@ -29,25 +28,6 @@ func TestSecurity(t *testing.T) { if got, want := sec.Type(), zmq4.PlainSecurity; got != want { t.Fatalf("got=%v, want=%v", got, want) } - - data := []byte("hello world") - wenc := new(bytes.Buffer) - if _, err := sec.Encrypt(wenc, data); err != nil { - t.Fatalf("error encrypting data: %+v", err) - } - - if !bytes.Equal(wenc.Bytes(), data) { - t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data) - } - - wdec := new(bytes.Buffer) - if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil { - t.Fatalf("error decrypting data: %+v", err) - } - - if !bytes.Equal(wdec.Bytes(), data) { - t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data) - } } func TestHandshakeReqRep(t *testing.T) { diff --git a/security_test.go b/security_test.go index f293b1d..9ec2df0 100644 --- a/security_test.go +++ b/security_test.go @@ -5,7 +5,6 @@ package zmq4 import ( - "bytes" "context" "fmt" "os" @@ -22,25 +21,6 @@ func TestNullSecurity(t *testing.T) { if got, want := sec.Type(), NullSecurity; got != want { t.Fatalf("got=%v, want=%v", got, want) } - - data := []byte("hello world") - wenc := new(bytes.Buffer) - if _, err := sec.Encrypt(wenc, data); err != nil { - t.Fatalf("error encrypting data: %+v", err) - } - - if !bytes.Equal(wenc.Bytes(), data) { - t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data) - } - - wdec := new(bytes.Buffer) - if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil { - t.Fatalf("error decrypting data: %+v", err) - } - - if !bytes.Equal(wdec.Bytes(), data) { - t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data) - } } func TestNullHandshakeReqRep(t *testing.T) {