Skip to content

Commit 6c262c1

Browse files
committed
Add curve preferences, pinned public key SHA256 and mTLS for TLS options
1 parent db686fd commit 6c262c1

File tree

7 files changed

+565
-126
lines changed

7 files changed

+565
-126
lines changed

common/tls/reality_server.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ func NewRealityServer(ctx context.Context, logger log.ContextLogger, options opt
6868
return nil, E.New("unknown cipher_suite: ", cipherSuite)
6969
}
7070
}
71-
if len(options.Certificate) > 0 || options.CertificatePath != "" {
71+
if len(options.CurvePreferences) > 0 {
72+
return nil, E.New("curve preferences is unavailable in reality")
73+
}
74+
if len(options.Certificate) > 0 || options.CertificatePath != "" || len(options.ClientCertificatePublicKeySHA256) > 0 {
7275
return nil, E.New("certificate is unavailable in reality")
7376
}
7477
if len(options.Key) > 0 || options.KeyPath != "" {

common/tls/std_client.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package tls
22

33
import (
4+
"bytes"
45
"context"
6+
"crypto/sha256"
57
"crypto/tls"
68
"crypto/x509"
9+
"encoding/base64"
710
"net"
811
"os"
912
"strings"
@@ -108,6 +111,15 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
108111
return err
109112
}
110113
}
114+
if len(options.CertificatePublicKeySHA256) > 0 {
115+
if len(options.Certificate) > 0 || options.CertificatePath != "" {
116+
return nil, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
117+
}
118+
tlsConfig.InsecureSkipVerify = true
119+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
120+
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
121+
}
122+
}
111123
if len(options.ALPN) > 0 {
112124
tlsConfig.NextProtos = options.ALPN
113125
}
@@ -137,6 +149,9 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
137149
return nil, E.New("unknown cipher_suite: ", cipherSuite)
138150
}
139151
}
152+
for _, curve := range options.CurvePreferences {
153+
tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curve))
154+
}
140155
var certificate []byte
141156
if len(options.Certificate) > 0 {
142157
certificate = []byte(strings.Join(options.Certificate, "\n"))
@@ -175,3 +190,22 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
175190
}
176191
return config, nil
177192
}
193+
194+
func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error {
195+
leafCertificate, err := x509.ParseCertificate(rawCerts[0])
196+
if err != nil {
197+
return E.Cause(err, "failed to parse leaf certificate")
198+
}
199+
200+
pubKeyBytes, err := x509.MarshalPKIXPublicKey(leafCertificate.PublicKey)
201+
if err != nil {
202+
return E.Cause(err, "failed to marshal public key")
203+
}
204+
hashValue := sha256.Sum256(pubKeyBytes)
205+
for _, value := range knownHashValues {
206+
if bytes.Equal(value, hashValue[:]) {
207+
return nil
208+
}
209+
}
210+
return E.New("unrecognized remote public key: ", base64.StdEncoding.EncodeToString(hashValue[:]))
211+
}

common/tls/std_server.go

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tls
33
import (
44
"context"
55
"crypto/tls"
6+
"crypto/x509"
67
"net"
78
"os"
89
"strings"
@@ -22,16 +23,17 @@ import (
2223
var errInsecureUnused = E.New("tls: insecure unused")
2324

2425
type STDServerConfig struct {
25-
access sync.RWMutex
26-
config *tls.Config
27-
logger log.Logger
28-
acmeService adapter.SimpleLifecycle
29-
certificate []byte
30-
key []byte
31-
certificatePath string
32-
keyPath string
33-
echKeyPath string
34-
watcher *fswatch.Watcher
26+
access sync.RWMutex
27+
config *tls.Config
28+
logger log.Logger
29+
acmeService adapter.SimpleLifecycle
30+
certificate []byte
31+
key []byte
32+
certificatePath string
33+
keyPath string
34+
clientCertificatePath []string
35+
echKeyPath string
36+
watcher *fswatch.Watcher
3537
}
3638

3739
func (c *STDServerConfig) ServerName() string {
@@ -111,6 +113,9 @@ func (c *STDServerConfig) startWatcher() error {
111113
if c.echKeyPath != "" {
112114
watchPath = append(watchPath, c.echKeyPath)
113115
}
116+
if len(c.clientCertificatePath) > 0 {
117+
watchPath = append(watchPath, c.clientCertificatePath...)
118+
}
114119
if len(watchPath) == 0 {
115120
return nil
116121
}
@@ -159,6 +164,30 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
159164
c.config = config
160165
c.access.Unlock()
161166
c.logger.Info("reloaded TLS certificate")
167+
} else if common.Contains(c.clientCertificatePath, path) {
168+
clientCertificateCA := x509.NewCertPool()
169+
var reloaded bool
170+
for _, certPath := range c.clientCertificatePath {
171+
content, err := os.ReadFile(certPath)
172+
if err != nil {
173+
c.logger.Error(E.Cause(err, "reload certificate from ", c.clientCertificatePath))
174+
continue
175+
}
176+
if !clientCertificateCA.AppendCertsFromPEM(content) {
177+
c.logger.Error(E.New("invalid client certificate file: ", certPath))
178+
continue
179+
}
180+
reloaded = true
181+
}
182+
if !reloaded {
183+
return E.New("client certificates is empty")
184+
}
185+
c.access.Lock()
186+
config := c.config.Clone()
187+
config.ClientCAs = clientCertificateCA
188+
c.config = config
189+
c.access.Unlock()
190+
c.logger.Info("reloaded client certificates")
162191
} else if path == c.echKeyPath {
163192
echKey, err := os.ReadFile(c.echKeyPath)
164193
if err != nil {
@@ -235,8 +264,14 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
235264
return nil, E.New("unknown cipher_suite: ", cipherSuite)
236265
}
237266
}
238-
var certificate []byte
239-
var key []byte
267+
for _, curveID := range options.CurvePreferences {
268+
tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curveID))
269+
}
270+
tlsConfig.ClientAuth = tls.ClientAuthType(options.ClientAuthentication)
271+
var (
272+
certificate []byte
273+
key []byte
274+
)
240275
if acmeService == nil {
241276
if len(options.Certificate) > 0 {
242277
certificate = []byte(strings.Join(options.Certificate, "\n"))
@@ -278,6 +313,43 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
278313
tlsConfig.Certificates = []tls.Certificate{keyPair}
279314
}
280315
}
316+
if len(options.ClientCertificate) > 0 || len(options.ClientCertificatePath) > 0 {
317+
if tlsConfig.ClientAuth == tls.NoClientCert {
318+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
319+
}
320+
}
321+
if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven || tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
322+
if len(options.ClientCertificate) > 0 {
323+
clientCertificateCA := x509.NewCertPool()
324+
if !clientCertificateCA.AppendCertsFromPEM([]byte(strings.Join(options.ClientCertificate, "\n"))) {
325+
return nil, E.New("invalid client certificate strings")
326+
}
327+
tlsConfig.ClientCAs = clientCertificateCA
328+
} else if len(options.ClientCertificatePath) > 0 {
329+
clientCertificateCA := x509.NewCertPool()
330+
for _, path := range options.ClientCertificatePath {
331+
content, err := os.ReadFile(path)
332+
if err != nil {
333+
return nil, E.Cause(err, "read client certificate from ", path)
334+
}
335+
if !clientCertificateCA.AppendCertsFromPEM(content) {
336+
return nil, E.New("invalid client certificate file: ", path)
337+
}
338+
}
339+
tlsConfig.ClientCAs = clientCertificateCA
340+
} else if len(options.ClientCertificatePublicKeySHA256) > 0 {
341+
if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
342+
tlsConfig.ClientAuth = tls.RequireAnyClientCert
343+
} else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven {
344+
tlsConfig.ClientAuth = tls.RequestClientCert
345+
}
346+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
347+
return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
348+
}
349+
} else {
350+
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
351+
}
352+
}
281353
var echKeyPath string
282354
if options.ECH != nil && options.ECH.Enabled {
283355
err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
@@ -286,14 +358,15 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
286358
}
287359
}
288360
serverConfig := &STDServerConfig{
289-
config: tlsConfig,
290-
logger: logger,
291-
acmeService: acmeService,
292-
certificate: certificate,
293-
key: key,
294-
certificatePath: options.CertificatePath,
295-
keyPath: options.KeyPath,
296-
echKeyPath: echKeyPath,
361+
config: tlsConfig,
362+
logger: logger,
363+
acmeService: acmeService,
364+
certificate: certificate,
365+
key: key,
366+
certificatePath: options.CertificatePath,
367+
clientCertificatePath: options.ClientCertificatePath,
368+
keyPath: options.KeyPath,
369+
echKeyPath: echKeyPath,
297370
}
298371
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
299372
serverConfig.access.Lock()

common/tls/utls_client.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
167167
}
168168
tlsConfig.InsecureServerNameToVerify = serverName
169169
}
170+
if len(options.CertificatePublicKeySHA256) > 0 {
171+
if len(options.Certificate) > 0 || options.CertificatePath != "" {
172+
return nil, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
173+
}
174+
tlsConfig.InsecureSkipVerify = true
175+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
176+
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
177+
}
178+
}
170179
if len(options.ALPN) > 0 {
171180
tlsConfig.NextProtos = options.ALPN
172181
}

0 commit comments

Comments
 (0)