Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion internal/provider/kubernetes/secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
package kubernetes

import (
"bytes"
"context"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"maps"
"reflect"
"time"

corev1 "k8s.io/api/core/v1"
kerrors "k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -114,7 +119,21 @@ func CreateOrUpdateSecrets(ctx context.Context, client client.Client, secrets []
existingSecrets = append(existingSecrets, fmt.Sprintf("%s/%s", secret.Namespace, secret.Name))
continue
}
fmt.Println()

// Bundle the old CA with the new CA for backwards-compatible rotation.
// Pods pick up updated secrets at different times, so during the rotation
// window some pods may still present leaf certs signed by the old CA.
// Keeping both CAs in the trust bundle prevents mTLS failures while all
// components converge to the new certificates.
if oldCA := current.Data[caCertificateKey]; len(oldCA) > 0 {
if newCA := secret.Data[caCertificateKey]; len(newCA) > 0 {
if bundled := bundleCACerts(newCA, oldCA); !bytes.Equal(bundled, newCA) {
newData := maps.Clone(secret.Data)
newData[caCertificateKey] = bundled
secret.Data = newData
}
}
}

if !reflect.DeepEqual(secret.Data, current.Data) {
if err := client.Update(ctx, &secret); err != nil {
Expand All @@ -133,3 +152,58 @@ func CreateOrUpdateSecrets(ctx context.Context, client client.Client, secrets []

return tidySecrets, nil
}

// bundleCACerts returns a PEM bundle containing all certificates from newCA
// followed by any non-expired, non-duplicate certificates from oldCA. This
// allows components that haven't yet reloaded to continue trusting leaf certs
// signed by the previous CA while simultaneously trusting the new CA.
func bundleCACerts(newCA, oldCA []byte) []byte {
if bytes.Equal(newCA, oldCA) {
return newCA
}

// Index the certs already present in newCA by their raw DER bytes.
existing := make(map[string]struct{})
for rest := newCA; ; {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type == "CERTIFICATE" {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
existing[string(cert.Raw)] = struct{}{}
}
}
}

// Append only the first non-expired, non-duplicate cert from oldCA.
// This is always the CA that was active at the last rotation. Carrying
// forward only one previous CA keeps the bundle at a maximum of two
// entries regardless of rotation frequency: by the time a second rotation
// occurs all components should have converged on the previous rotation's
// certs, so earlier CAs are no longer needed.
result := make([]byte, len(newCA))
copy(result, newCA)
now := time.Now()
for rest := oldCA; ; {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil || cert.NotAfter.Before(now) {
continue
}
if _, dup := existing[string(cert.Raw)]; dup {
continue
}
result = append(result, pem.EncodeToMemory(block)...)
break // only carry forward one previous CA
}
return result
}
144 changes: 144 additions & 0 deletions internal/provider/kubernetes/secrets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@ package kubernetes

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -95,3 +103,139 @@ func TestCreateSecretsWhenUpgrade(t *testing.T) {
require.Len(t, created, 4)
})
}

// TestCreateOrUpdateSecretsBundlesCA verifies that when rotating certificates the
// old CA is bundled with the new CA so that components that haven't reloaded yet
// continue to be trusted during the transition window.
func TestCreateOrUpdateSecretsBundlesCA(t *testing.T) {
now := time.Now()
ca1 := makeTestCAPEM(t, now.Add(365*24*time.Hour))
ca2 := makeTestCAPEM(t, now.Add(365*24*time.Hour))
require.NotEqual(t, ca1, ca2)

// Seed the cluster with a secret carrying the old CA.
existing := newSecret(corev1.SecretTypeTLS, "envoy-gateway", "test-ns", map[string][]byte{
caCertificateKey: ca1,
corev1.TLSCertKey: []byte("old-cert"),
corev1.TLSPrivateKeyKey: []byte("old-key"),
})
cli := fakeclient.NewClientBuilder().WithObjects(&existing).Build()

// Rotate: present a new secret carrying the new CA.
rotated := newSecret(corev1.SecretTypeTLS, "envoy-gateway", "test-ns", map[string][]byte{
caCertificateKey: ca2,
corev1.TLSCertKey: []byte("new-cert"),
corev1.TLSPrivateKeyKey: []byte("new-key"),
})
updated, err := CreateOrUpdateSecrets(context.Background(), cli, []corev1.Secret{rotated}, true)
require.NoError(t, err)
require.Len(t, updated, 1)

bundle := updated[0].Data[caCertificateKey]

// The bundle must be valid PEM accepted by x509.
pool := x509.NewCertPool()
require.True(t, pool.AppendCertsFromPEM(bundle), "bundle must be valid PEM")

certs := decodePEMCerts(t, bundle)
require.Len(t, certs, 2, "bundle must contain exactly 2 certs: new CA + old CA")
assert.Equal(t, decodePEMCerts(t, ca2)[0].Raw, certs[0].Raw, "first cert in bundle must be the new CA")
assert.Equal(t, decodePEMCerts(t, ca1)[0].Raw, certs[1].Raw, "second cert in bundle must be the old CA")
}

// TestBundleCACerts covers the bundleCACerts helper directly.
func TestBundleCACerts(t *testing.T) {
now := time.Now()
ca1 := makeTestCAPEM(t, now.Add(365*24*time.Hour))
ca2 := makeTestCAPEM(t, now.Add(365*24*time.Hour))
expired := makeTestCAPEM(t, now.Add(-time.Second)) // already expired

t.Run("identical CAs return the original bytes unchanged", func(t *testing.T) {
result := bundleCACerts(ca1, ca1)
assert.Equal(t, ca1, result)
})

t.Run("different CAs are concatenated new-first", func(t *testing.T) {
result := bundleCACerts(ca2, ca1)
certs := decodePEMCerts(t, result)
require.Len(t, certs, 2)
assert.Equal(t, decodePEMCerts(t, ca2)[0].Raw, certs[0].Raw, "new CA must be first")
assert.Equal(t, decodePEMCerts(t, ca1)[0].Raw, certs[1].Raw, "old CA must be second")
})

t.Run("cert already present in newCA is not duplicated", func(t *testing.T) {
// Bundle containing ca2+ca1; applying bundleCACerts(ca2, bundle) must
// not add ca2 again.
bundle := bundleCACerts(ca2, ca1)
result := bundleCACerts(ca2, bundle)
certs := decodePEMCerts(t, result)
require.Len(t, certs, 2, "ca2 that is already in newCA must not be re-appended")
})

t.Run("expired cert from old bundle is excluded", func(t *testing.T) {
result := bundleCACerts(ca1, expired)
certs := decodePEMCerts(t, result)
require.Len(t, certs, 1, "expired cert must not be included in the bundle")
})

t.Run("bundle never exceeds two CAs across multiple rotations", func(t *testing.T) {
ca3 := makeTestCAPEM(t, now.Add(365*24*time.Hour))

// Rotation 1: CA1 -> CA2
after1 := bundleCACerts(ca2, ca1)
require.Len(t, decodePEMCerts(t, after1), 2)

// Rotation 2: CA2 -> CA3 (oldCA is the after1 bundle: CA2+CA1)
// Only CA2 (the head of after1) should be carried forward; CA1 is dropped.
after2 := bundleCACerts(ca3, after1)
certs := decodePEMCerts(t, after2)
require.Len(t, certs, 2, "bundle must not grow beyond 2 CAs after a second rotation")
assert.Equal(t, decodePEMCerts(t, ca3)[0].Raw, certs[0].Raw, "new CA must be first")
assert.Equal(t, decodePEMCerts(t, ca2)[0].Raw, certs[1].Raw, "only the immediately-previous CA is carried forward")
})
}

// makeTestCAPEM generates a minimal self-signed CA certificate with the given
// validity window and returns it as a PEM-encoded CERTIFICATE block.
func makeTestCAPEM(t *testing.T, notAfter time.Time) []byte {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

now := time.Now()
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(t, err)

template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "test-ca"},
NotBefore: now.Add(-time.Hour),
NotAfter: notAfter,
IsCA: true,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
}

// decodePEMCerts returns all x509 certificates decoded from a PEM bundle.
func decodePEMCerts(t *testing.T, data []byte) []*x509.Certificate {
t.Helper()
var certs []*x509.Certificate
for rest := data; len(rest) > 0; {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err)
certs = append(certs, cert)
}
return certs
}
Loading