Skip to content
3 changes: 2 additions & 1 deletion cns/configuration/cns_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@
"AZRSettings": {
"PopulateHomeAzCacheRetryIntervalSecs": 60
},
"MinTLSVersion": "TLS 1.2"
"MinTLSVersion": "TLS 1.2",
"MtlsClientCertSubjectName": ""
}
1 change: 1 addition & 0 deletions cns/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type CNSConfig struct {
WireserverIP string
GRPCSettings GRPCSettings
MinTLSVersion string
MtlsClientCertSubjectName string
}

type TelemetrySettings struct {
Expand Down
9 changes: 6 additions & 3 deletions cns/configuration/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "localhost",
Port: 8080,
},
MinTLSVersion: "TLS 1.2",
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "",
},
},
{
Expand Down Expand Up @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "192.168.1.1",
Port: 9090,
},
MinTLSVersion: "TLS 1.3",
MinTLSVersion: "TLS 1.3",
MtlsClientCertSubjectName: "example.com",
},
want: CNSConfig{
ChannelMode: "Other",
Expand Down Expand Up @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "192.168.1.1",
Port: 9090,
},
MinTLSVersion: "TLS 1.3",
MinTLSVersion: "TLS 1.3",
MtlsClientCertSubjectName: "example.com",
},
},
}
Expand Down
55 changes: 54 additions & 1 deletion cns/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,54 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.
return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
}

// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name.
func verifyPeerCertificate(verifiedChains [][]*x509.Certificate, clientSubjectName string) error {
// no client subject name provided, skip verification
if clientSubjectName == "" {
return nil
}

if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 {
return errors.New("no client certificate provided during mTLS")
}

// Get client leaf certificate
clientCert := verifiedChains[0][0]
// Match DNS names (case-insensitive)
dnsNames := clientCert.DNSNames
for _, dns := range dnsNames {
if strings.EqualFold(dns, clientSubjectName) {
return nil
}
}

// If SANs didn't match, fall back to Common Name (CN) match.
clientCN := clientCert.Subject.CommonName
if clientCN != "" && strings.EqualFold(clientCN, clientSubjectName) {
return nil
}

// maskHalf of the DNS names
maskedDNS := make([]string, len(dnsNames))
for i, dns := range dnsNames {
maskedDNS[i] = maskHalf(dns)
}

return errors.Errorf("Failed to verify client certificate subject name during mTLS, clientSubjectName: %s, client cert SANs: %+v, clientCN: %s",
clientSubjectName, maskedDNS, maskHalf(clientCN))
}

// maskHalf masks half of the input string with asterisks.
func maskHalf(s string) string {
n := len(s)
if n == 0 {
return s
}

half := n / 2
return s[:half] + strings.Repeat("*", n-half)
}

func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
if err != nil {
Expand Down Expand Up @@ -202,8 +250,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName)
}
}

logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)

return tlsConfig, nil
Expand Down Expand Up @@ -254,6 +304,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName)
}
}

logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
Expand Down
1 change: 1 addition & 0 deletions cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ func main() {
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
UseMTLS: cnsconfig.UseMTLS,
MinTLSVersion: cnsconfig.MinTLSVersion,
MtlsClientCertSubjectName: cnsconfig.MtlsClientCertSubjectName,
}
}

Expand Down
160 changes: 118 additions & 42 deletions cns/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,57 +133,108 @@ func TestNewService(t *testing.T) {
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
testCertFilePath := createTestCertificate(t)

config.TLSSettings = serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
cases := []struct {
name string
tlsSettings serverTLS.TlsSettings
handshakeFailureExpected bool
}{
{
name: "matching client SANs",
tlsSettings: serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "example.com",
},
handshakeFailureExpected: false,
},
{
name: "matching client cert CN",
tlsSettings: serverTLS.TlsSettings{
TLSPort: "10093",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "foo.com", // Common Name from test certificate
},
handshakeFailureExpected: false,
},
{
name: "failing to match client SANs and CN",
tlsSettings: serverTLS.TlsSettings{
TLSPort: "10092",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "random.com",
},
handshakeFailureExpected: true,
},
}

svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
config.TLSSettings = tc.tlsSettings

svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)

err = svc.Initialize(config)
t.Cleanup(func() {
svc.Uninitialize()
})
require.NoError(t, err)
svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")

err = svc.StartListener(config)
require.NoError(t, err)
err = svc.Initialize(config)
require.NoError(t, err)

mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)
err = svc.StartListener(config)
require.NoError(t, err)

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)

// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}

// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
tlsURL := "https://localhost:" + tc.tlsSettings.TLSPort
// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
if tc.handshakeFailureExpected {
require.Error(t, err)
require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS")
} else {
require.NoError(t, err)
t.Cleanup(func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
})
}

// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
})

// Cleanup
svc.Uninitialize()
})
}
})
}

Expand Down Expand Up @@ -355,3 +406,28 @@ func TestTLSVersionNumber(t *testing.T) {
require.NoError(t, err)
})
}

func TestMaskHalf(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{"empty", "", ""},
{"one char string", "e", "*"},
{"two chars string", "ex", "e*"},
{"three chars string", "exa", "e**"},
{"four chars string", "exam", "ex**"},
{"five chars string", "examp", "ex***"},
{"long string", "example.com", "examp******"},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := maskHalf(tc.in)
if got != tc.want {
t.Fatalf("maskHalf(%s) = %s, want %s", tc.in, got, tc.want)
}
})
}
}
1 change: 1 addition & 0 deletions server/tls/tlscertificate_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type TlsSettings struct {
KeyVaultCertificateRefreshInterval time.Duration
UseMTLS bool
MinTLSVersion string
MtlsClientCertSubjectName string
}

func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {
Expand Down
Loading