Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions server/server/rpc/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ type certLoader struct {
}

func (l *certLoader) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
stat, err := os.Stat(l.KeyFile)
stat, err := os.Stat(l.CertFile)
if err != nil {
l.lock.RLock()
existingCert := l.cachedCert
l.lock.RUnlock()

if existingCert == nil {
return nil, fmt.Errorf("statting tls key file: %w", err)
return nil, fmt.Errorf("statting tls cert file: %w", err)
}

log.Printf("unable to stat tls key file, returning cached cert which may expire: %s", err)
log.Printf("unable to stat tls cert file, returning cached cert which may expire: %s", err)
return existingCert, nil
}

Expand Down
123 changes: 84 additions & 39 deletions server/server/rpc/tls_cert_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,106 @@ import (
)

func TestCertLoader_ReloadsNewKeyPair(t *testing.T) {
// Use Go testing temporary directory for test files
dir := t.TempDir()
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")
tests := []struct {
name string
reuseKey bool
}{
{
name: "regenerate both cert and key",
reuseKey: false,
},
{
name: "regenerate only cert with same key",
reuseKey: true,
},
}

// Write initial certificate and key
certPEM1, keyPEM1 := generateCertKeyPair(t, "initial")
assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644))
assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use Go testing temporary directory for test files
dir := t.TempDir()
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")

// Initialize the loader
loader := &certLoader{CertFile: certPath, KeyFile: keyPath}
// Write initial certificate and key
certPEM1, keyPEM1 := generateCertKeyPair(t, "initial")
assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644))
assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644))

// First load should read the initial pair
loaded1, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded1)
// Initialize the loader
loader := &certLoader{CertFile: certPath, KeyFile: keyPath}

// Compare against a direct X509KeyPair parse
expect1, err := tls.X509KeyPair(certPEM1, keyPEM1)
assert.NoError(t, err)
assert.Equal(t, expect1.Certificate, loaded1.Certificate)
// First load should read the initial pair
loaded1, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded1)

// Wait to ensure file modification time will differ
time.Sleep(500 * time.Millisecond)
// Compare against a direct X509KeyPair parse
expect1, err := tls.X509KeyPair(certPEM1, keyPEM1)
assert.NoError(t, err)
assert.Equal(t, expect1.Certificate, loaded1.Certificate)

// Overwrite with a new certificate and key
certPEM2, keyPEM2 := generateCertKeyPair(t, "updated")
assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644))
assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644))
// Wait to ensure file modification time will differ
time.Sleep(500 * time.Millisecond)

// Second load should pick up the updated pair
loaded2, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded2)
var certPEM2, keyPEM2 []byte
if tt.reuseKey {
// Generate new certificate with the same key
certPEM2 = generateCertForKey(t, "updated", keyPEM1)
keyPEM2 = keyPEM1
} else {
// Generate new certificate and key
certPEM2, keyPEM2 = generateCertKeyPair(t, "updated")
}

expect2, err := tls.X509KeyPair(certPEM2, keyPEM2)
assert.NoError(t, err)
// Write the new certificate
assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644))

// Compare the loaded certificate with the expected one
assert.Equal(t, expect2.Certificate, loaded2.Certificate)
// Write the new key if we are not reusing the existing key
if !tt.reuseKey {
assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644))
}

// Ensure the loader did not return the old certificate
assert.NotEqual(t, expect1.Certificate, loaded2.Certificate)
// Second load should pick up the updated pair
loaded2, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded2)

expect2, err := tls.X509KeyPair(certPEM2, keyPEM2)
assert.NoError(t, err)

// Compare the loaded certificate with the expected one
assert.Equal(t, expect2.Certificate, loaded2.Certificate)

// Compare the loaded private key with the expected one
assert.Equal(t, expect2.PrivateKey, loaded2.PrivateKey)

// Ensure the loader did not return the old certificate
assert.NotEqual(t, expect1.Certificate, loaded2.Certificate)
})
}
}

// generateCertKeyPair creates a self-signed certificate and private key for testing.
func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byte) {
// Generate a private key
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
// PEM encode the key
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
// Generate certificate with this key
certPEM = generateCertForKey(t, commonName, keyPEM)
return certPEM, keyPEM
}

// generateCertForKey creates a self-signed certificate using an existing key for testing.
func generateCertForKey(t *testing.T, commonName string, existingKeyPEM []byte) []byte {
// Parse the existing key
block, _ := pem.Decode(existingKeyPEM)
assert.NotNil(t, block)
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.NoError(t, err)

// Create a certificate template
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
Expand All @@ -85,9 +134,5 @@ func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byt
// Self-sign the certificate
derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key)
assert.NoError(t, err)
// PEM encode the certificate and key
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
return certPEM, keyPEM
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}

Loading