diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 81ef6c9b05..967089d142 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -35,5 +35,6 @@ "AZRSettings": { "PopulateHomeAzCacheRetryIntervalSecs": 60 }, - "MinTLSVersion": "TLS 1.2" + "MinTLSVersion": "TLS 1.2", + "MtlsClientCertSubjectName": "" } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 9ec5f8664f..b5fc0e4114 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -59,6 +59,7 @@ type CNSConfig struct { WireserverIP string GRPCSettings GRPCSettings MinTLSVersion string + MtlsClientCertSubjectName string } type TelemetrySettings struct { diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 186c92c376..ab3d93ebd1 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "localhost", Port: 8080, }, - MinTLSVersion: "TLS 1.2", + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "", }, }, { @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + MtlsClientCertSubjectName: "example.com", }, want: CNSConfig{ ChannelMode: "Other", @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + MtlsClientCertSubjectName: "example.com", }, }, } diff --git a/cns/service.go b/cns/service.go index ab7a0be3c3..5ee249c9fc 100644 --- a/cns/service.go +++ b/cns/service.go @@ -156,6 +156,54 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings) } +// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. +func verifyPeerCertificate(verifiedChains [][]*x509.Certificate, clientSubjectName string) error { + // no client subject name provided, skip verification + if clientSubjectName == "" { + return nil + } + + if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 { + return errors.New("no client certificate provided during mTLS") + } + + // Get client leaf certificate + clientCert := verifiedChains[0][0] + // Match DNS names (case-insensitive) + dnsNames := clientCert.DNSNames + for _, dns := range dnsNames { + if strings.EqualFold(dns, clientSubjectName) { + return nil + } + } + + // If SANs didn't match, fall back to Common Name (CN) match. + clientCN := clientCert.Subject.CommonName + if clientCN != "" && strings.EqualFold(clientCN, clientSubjectName) { + return nil + } + + // maskHalf of the DNS names + maskedDNS := make([]string, len(dnsNames)) + for i, dns := range dnsNames { + maskedDNS[i] = maskHalf(dns) + } + + return errors.Errorf("Failed to verify client certificate subject name during mTLS, clientSubjectName: %s, client cert SANs: %+v, clientCN: %s", + clientSubjectName, maskedDNS, maskHalf(clientCN)) +} + +// maskHalf masks half of the input string with asterisks. +func maskHalf(s string) string { + n := len(s) + if n == 0 { + return s + } + + half := n / 2 + return s[:half] + strings.Repeat("*", n-half) +} + func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) { tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings) if err != nil { @@ -202,8 +250,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { + return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName) + } } - logger.Debugf("TLS configured successfully from file: %+v", tlsSettings) return tlsConfig, nil @@ -254,6 +304,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { + return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName) + } } logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings) diff --git a/cns/service/main.go b/cns/service/main.go index 67f7872f44..d7b9a526d5 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -810,6 +810,7 @@ func main() { KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, UseMTLS: cnsconfig.UseMTLS, MinTLSVersion: cnsconfig.MinTLSVersion, + MtlsClientCertSubjectName: cnsconfig.MtlsClientCertSubjectName, } } diff --git a/cns/service_test.go b/cns/service_test.go index d20c2ef11a..fd0b7a44b4 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -133,57 +133,108 @@ func TestNewService(t *testing.T) { t.Run("NewServiceWithMutualTLS", func(t *testing.T) { testCertFilePath := createTestCertificate(t) - config.TLSSettings = serverTLS.TlsSettings{ - TLSPort: "10091", - TLSSubjectName: "localhost", - TLSCertificatePath: testCertFilePath, - UseMTLS: true, - MinTLSVersion: "TLS 1.2", + cases := []struct { + name string + tlsSettings serverTLS.TlsSettings + handshakeFailureExpected bool + }{ + { + name: "matching client SANs", + tlsSettings: serverTLS.TlsSettings{ + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "example.com", + }, + handshakeFailureExpected: false, + }, + { + name: "matching client cert CN", + tlsSettings: serverTLS.TlsSettings{ + TLSPort: "10093", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "foo.com", // Common Name from test certificate + }, + handshakeFailureExpected: false, + }, + { + name: "failing to match client SANs and CN", + tlsSettings: serverTLS.TlsSettings{ + TLSPort: "10092", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "random.com", + }, + handshakeFailureExpected: true, + }, } - svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) - require.NoError(t, err) - require.IsType(t, &Service{}, svc) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + config.TLSSettings = tc.tlsSettings - svc.SetOption(acn.OptCnsURL, "") - svc.SetOption(acn.OptCnsPort, "") + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) - err = svc.Initialize(config) - t.Cleanup(func() { - svc.Uninitialize() - }) - require.NoError(t, err) + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") - err = svc.StartListener(config) - require.NoError(t, err) + err = svc.Initialize(config) + require.NoError(t, err) - mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) - require.NoError(t, err) + err = svc.StartListener(config) + require.NoError(t, err) - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: mTLSConfig, - }, - } + mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) + require.NoError(t, err) - // TLS listener - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody) - require.NoError(t, err) - resp, err := client.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: mTLSConfig, + }, + } - // HTTP listener - httpClient := &http.Client{} - req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) - require.NoError(t, err) - resp, err = httpClient.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + tlsURL := "https://localhost:" + tc.tlsSettings.TLSPort + // TLS listener + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + if tc.handshakeFailureExpected { + require.Error(t, err) + require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS") + } else { + require.NoError(t, err) + t.Cleanup(func() { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + }) + } + + // HTTP listener + httpClient := &http.Client{} + req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err = httpClient.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + }) + + // Cleanup + svc.Uninitialize() + }) + } }) } @@ -355,3 +406,28 @@ func TestTLSVersionNumber(t *testing.T) { require.NoError(t, err) }) } + +func TestMaskHalf(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {"empty", "", ""}, + {"one char string", "e", "*"}, + {"two chars string", "ex", "e*"}, + {"three chars string", "exa", "e**"}, + {"four chars string", "exam", "ex**"}, + {"five chars string", "examp", "ex***"}, + {"long string", "example.com", "examp******"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := maskHalf(tc.in) + if got != tc.want { + t.Fatalf("maskHalf(%s) = %s, want %s", tc.in, got, tc.want) + } + }) + } +} diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index a22a7336b7..b6a0d11099 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -15,6 +15,7 @@ type TlsSettings struct { KeyVaultCertificateRefreshInterval time.Duration UseMTLS bool MinTLSVersion string + MtlsClientCertSubjectName string } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {