Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ type IAM struct {
// Region configurable custom region for STS
Region string

// ExpiryWindow allows customizing the window before credentials expire
// when they should be refreshed. Defaults to DefaultExpiryWindow.
ExpiryWindow time.Duration

// Support for container authorization token https://docs.aws.amazon.com/sdkref/latest/guide/feature-container-credentials.html
Container struct {
AuthorizationToken string
Expand Down Expand Up @@ -94,6 +98,23 @@ func NewIAM(endpoint string) *Credentials {
})
}

// IAMConfig contains configuration options for the IAM credentials provider
type IAMConfig struct {
// ExpiryWindow allows customizing the window before credentials expire
// when they should be refreshed. Use DefaultExpiryWindow for the default behavior.
ExpiryWindow time.Duration
}

// NewIAMWithConfig returns a pointer to a new Credentials object wrapping the IAM
// with custom configuration.
// The config parameter allows customizing various aspects of the IAM provider.
func NewIAMWithConfig(endpoint string, config IAMConfig) *Credentials {
return New(&IAM{
Endpoint: endpoint,
ExpiryWindow: config.ExpiryWindow,
})
}

// RetrieveWithCredContext is like Retrieve with Cred Context
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
Expand Down Expand Up @@ -220,8 +241,11 @@ func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if err != nil {
return Value{}, err
}
// Expiry window is set to 10secs.
m.SetExpiration(roleCreds.Expiration, DefaultExpiryWindow)
// Use custom expiry window if set, otherwise use default
if m.ExpiryWindow == 0 {
m.ExpiryWindow = DefaultExpiryWindow
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)

return Value{
AccessKeyID: roleCreds.AccessKeyID,
Expand Down
181 changes: 181 additions & 0 deletions pkg/credentials/iam_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,184 @@ func TestIMDSv1Blocked(t *testing.T) {
t.Errorf("Unexpected IMDSv2 failure %s", err)
}
}

func TestIAMCustomExpiryWindow(t *testing.T) {
server := initIMDSv2Server("2014-12-16T01:51:37Z", false)
defer server.Close()

// Test with custom expiry window of 5 minutes
customWindow := 5 * time.Minute
p := &IAM{
Endpoint: server.URL,
ExpiryWindow: customWindow,
}

// Set a known current time for predictable testing
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 0, 0, 0, time.UTC)
}

// retrieve credentials - triggers initial expiration calculation
creds, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}

if creds.AccessKeyID != "accessKey" {
t.Errorf("Expected \"accessKey\", got %s", creds.AccessKeyID)
}

// Verify that the custom expiry window was used
// The expiration time should be: original expiration - custom window
// Original: 2014-12-16T01:51:37Z
// Custom window: 5 minutes
// Expected expiry: 2014-12-16T01:46:37Z
expectedExpiry := time.Date(2014, 12, 16, 1, 46, 37, 0, time.UTC)
if !p.expiration.Equal(expectedExpiry) {
t.Errorf("Expected expiration %v, got %v", expectedExpiry, p.expiration)
}

// Credentials should not be expired at current time (2014-12-15 21:00:00)
if p.IsExpired() {
t.Error("Expected creds to not be expired with custom window.")
}

// Move time forward to just before expiry
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 1, 46, 0, 0, time.UTC)
}
if p.IsExpired() {
t.Error("Expected creds to not be expired yet.")
}

// Move time forward past the custom expiry window
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 1, 47, 0, 0, time.UTC)
}
if !p.IsExpired() {
t.Error("Expected creds to be expired after custom window.")
}
}

func TestIAMDefaultExpiryWindow(t *testing.T) {
server := initIMDSv2Server("2014-12-16T01:51:37Z", false)
defer server.Close()

// Test with default expiry window (should use 80% rule)
p := &IAM{
Endpoint: server.URL,
ExpiryWindow: DefaultExpiryWindow,
}

p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 0, 0, 0, time.UTC)
}

// retrieve credentials - triggers initial expiration calculation
creds, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}

if creds.AccessKeyID != "accessKey" {
t.Errorf("Expected \"accessKey\", got %s", creds.AccessKeyID)
}

// With default window, expiry should be calculated as:
// expiration - (80% of time until expiration)
// Time from current (2014-12-15 21:00:00) to expiration (2014-12-16 01:51:37) = 4h 51m 37s = 17497s
// 80% of that = 13997.6s ≈ 3h 53m 17.6s
// So expiry should be around: 2014-12-16 01:51:37 - 3h 53m 17.6s = 2014-12-15 21:58:19.4
// We'll check it's expired before the actual expiration time
originalExpiration := time.Date(2014, 12, 16, 1, 51, 37, 0, time.UTC)
if !p.expiration.Before(originalExpiration) {
t.Errorf("Expected expiration to be before original expiration time with default window")
}

// Credentials should not be expired at current time
if p.IsExpired() {
t.Error("Expected creds to not be expired initially.")
}
}

func TestNewIAMWithConfig(t *testing.T) {
server := initIMDSv2Server("2014-12-16T01:51:37Z", false)
defer server.Close()

// Test NewIAMWithConfig with custom expiry window
customWindow := 10 * time.Minute
config := IAMConfig{
ExpiryWindow: customWindow,
}

creds := NewIAMWithConfig(server.URL, config)
if creds == nil {
t.Fatal("Expected non-nil credentials")
}

// Verify the provider is properly configured
provider, ok := creds.provider.(*IAM)
if !ok {
t.Fatal("Expected provider to be *IAM")
}

if provider.Endpoint != server.URL {
t.Errorf("Expected endpoint %s, got %s", server.URL, provider.Endpoint)
}

if provider.ExpiryWindow != customWindow {
t.Errorf("Expected expiry window %v, got %v", customWindow, provider.ExpiryWindow)
}

// Set a known current time
provider.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 0, 0, 0, time.UTC)
}

// Retrieve credentials and verify custom window is applied
value, err := creds.GetWithContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}

if value.AccessKeyID != "accessKey" {
t.Errorf("Expected \"accessKey\", got %s", value.AccessKeyID)
}

// Verify expiration is set with custom window
expectedExpiry := time.Date(2014, 12, 16, 1, 41, 37, 0, time.UTC)
if !provider.expiration.Equal(expectedExpiry) {
t.Errorf("Expected expiration %v, got %v", expectedExpiry, provider.expiration)
}
}

func TestIAMZeroExpiryWindowUsesDefault(t *testing.T) {
server := initIMDSv2Server("2014-12-16T01:51:37Z", false)
defer server.Close()

// Test that zero expiry window falls back to default
p := &IAM{
Endpoint: server.URL,
ExpiryWindow: 0, // Explicitly set to zero
}

p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 0, 0, 0, time.UTC)
}

_, err := p.RetrieveWithCredContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}

// After retrieve, ExpiryWindow should be set to DefaultExpiryWindow
if p.ExpiryWindow != DefaultExpiryWindow {
t.Errorf("Expected ExpiryWindow to be DefaultExpiryWindow, got %v", p.ExpiryWindow)
}

// Verify default behavior (80% rule) is applied
originalExpiration := time.Date(2014, 12, 16, 1, 51, 37, 0, time.UTC)
if !p.expiration.Before(originalExpiration) {
t.Error("Expected expiration to be before original expiration time with default window")
}
}