diff --git a/server/server/rpc/tls.go b/server/server/rpc/tls.go index 382bb3b6ab..3e53e4f2ba 100644 --- a/server/server/rpc/tls.go +++ b/server/server/rpc/tls.go @@ -61,17 +61,17 @@ type certLoader struct { } func (l *certLoader) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { - stat, err := os.Stat(l.KeyFile) + stat, err := os.Stat(l.CertFile) if err != nil { l.lock.RLock() existingCert := l.cachedCert l.lock.RUnlock() if existingCert == nil { - return nil, fmt.Errorf("statting tls key file: %w", err) + return nil, fmt.Errorf("statting tls cert file: %w", err) } - log.Printf("unable to stat tls key file, returning cached cert which may expire: %s", err) + log.Printf("unable to stat tls cert file, returning cached cert which may expire: %s", err) return existingCert, nil } diff --git a/server/server/rpc/tls_cert_loader_test.go b/server/server/rpc/tls_cert_loader_test.go index 472d8daa72..37f0862383 100644 --- a/server/server/rpc/tls_cert_loader_test.go +++ b/server/server/rpc/tls_cert_loader_test.go @@ -17,50 +17,84 @@ import ( ) func TestCertLoader_ReloadsNewKeyPair(t *testing.T) { - // Use Go testing temporary directory for test files - dir := t.TempDir() - certPath := filepath.Join(dir, "cert.pem") - keyPath := filepath.Join(dir, "key.pem") + tests := []struct { + name string + reuseKey bool + }{ + { + name: "regenerate both cert and key", + reuseKey: false, + }, + { + name: "regenerate only cert with same key", + reuseKey: true, + }, + } - // Write initial certificate and key - certPEM1, keyPEM1 := generateCertKeyPair(t, "initial") - assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644)) - assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Use Go testing temporary directory for test files + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") - // Initialize the loader - loader := &certLoader{CertFile: certPath, KeyFile: keyPath} + // Write initial certificate and key + certPEM1, keyPEM1 := generateCertKeyPair(t, "initial") + assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644)) + assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644)) - // First load should read the initial pair - loaded1, err := loader.GetClientCertificate(nil) - assert.NoError(t, err) - assert.NotNil(t, loaded1) + // Initialize the loader + loader := &certLoader{CertFile: certPath, KeyFile: keyPath} - // Compare against a direct X509KeyPair parse - expect1, err := tls.X509KeyPair(certPEM1, keyPEM1) - assert.NoError(t, err) - assert.Equal(t, expect1.Certificate, loaded1.Certificate) + // First load should read the initial pair + loaded1, err := loader.GetClientCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, loaded1) - // Wait to ensure file modification time will differ - time.Sleep(500 * time.Millisecond) + // Compare against a direct X509KeyPair parse + expect1, err := tls.X509KeyPair(certPEM1, keyPEM1) + assert.NoError(t, err) + assert.Equal(t, expect1.Certificate, loaded1.Certificate) - // Overwrite with a new certificate and key - certPEM2, keyPEM2 := generateCertKeyPair(t, "updated") - assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644)) - assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644)) + // Wait to ensure file modification time will differ + time.Sleep(500 * time.Millisecond) - // Second load should pick up the updated pair - loaded2, err := loader.GetClientCertificate(nil) - assert.NoError(t, err) - assert.NotNil(t, loaded2) + var certPEM2, keyPEM2 []byte + if tt.reuseKey { + // Generate new certificate with the same key + certPEM2 = generateCertForKey(t, "updated", keyPEM1) + keyPEM2 = keyPEM1 + } else { + // Generate new certificate and key + certPEM2, keyPEM2 = generateCertKeyPair(t, "updated") + } - expect2, err := tls.X509KeyPair(certPEM2, keyPEM2) - assert.NoError(t, err) + // Write the new certificate + assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644)) - // Compare the loaded certificate with the expected one - assert.Equal(t, expect2.Certificate, loaded2.Certificate) + // Write the new key if we are not reusing the existing key + if !tt.reuseKey { + assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644)) + } - // Ensure the loader did not return the old certificate - assert.NotEqual(t, expect1.Certificate, loaded2.Certificate) + // Second load should pick up the updated pair + loaded2, err := loader.GetClientCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, loaded2) + + expect2, err := tls.X509KeyPair(certPEM2, keyPEM2) + assert.NoError(t, err) + + // Compare the loaded certificate with the expected one + assert.Equal(t, expect2.Certificate, loaded2.Certificate) + + // Compare the loaded private key with the expected one + assert.Equal(t, expect2.PrivateKey, loaded2.PrivateKey) + + // Ensure the loader did not return the old certificate + assert.NotEqual(t, expect1.Certificate, loaded2.Certificate) + }) + } } // generateCertKeyPair creates a self-signed certificate and private key for testing. @@ -68,6 +102,21 @@ func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byt // Generate a private key key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) + // PEM encode the key + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + // Generate certificate with this key + certPEM = generateCertForKey(t, commonName, keyPEM) + return certPEM, keyPEM +} + +// generateCertForKey creates a self-signed certificate using an existing key for testing. +func generateCertForKey(t *testing.T, commonName string, existingKeyPEM []byte) []byte { + // Parse the existing key + block, _ := pem.Decode(existingKeyPEM) + assert.NotNil(t, block) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.NoError(t, err) + // Create a certificate template serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) @@ -85,9 +134,5 @@ func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byt // Self-sign the certificate derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key) assert.NoError(t, err) - // PEM encode the certificate and key - certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - return certPEM, keyPEM + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) } -