diff --git a/cmd/eks-node-viewer/main.go b/cmd/eks-node-viewer/main.go index 3f3d369..e41af70 100644 --- a/cmd/eks-node-viewer/main.go +++ b/cmd/eks-node-viewer/main.go @@ -23,6 +23,7 @@ import ( "os" "strings" + awsSdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" tea "github.com/charmbracelet/bubbletea" "k8s.io/apimachinery/pkg/labels" @@ -68,7 +69,9 @@ func main() { } ctx, cancel := context.WithCancel(context.Background()) - pprov := aws.NewStaticPricingProvider() + region, profile := client.GetAWSRegionAndProfile(flags.Kubeconfig, flags.Context) + + pprov := aws.NewStaticPricingProvider(region) style, err := model.ParseStyle(flags.Style) if err != nil { log.Fatalf("creating style, %s", err) @@ -85,7 +88,13 @@ func main() { } if !flags.DisablePricing { - sess := session.Must(session.NewSessionWithOptions(session.Options{SharedConfigState: session.SharedConfigEnable})) + sess := session.Must(session.NewSessionWithOptions( + session.Options{ + Config: awsSdk.Config{Region: ®ion}, + Profile: profile, + SharedConfigState: session.SharedConfigEnable, + }, + )) pprov = aws.NewPricingProvider(ctx, sess) } controller := client.NewController(cs, nodeClaimClient, m, nodeSelector, pprov) diff --git a/pkg/aws/pricing.go b/pkg/aws/pricing.go index fffd8a3..797c91a 100644 --- a/pkg/aws/pricing.go +++ b/pkg/aws/pricing.go @@ -128,10 +128,12 @@ func getStaticPrices(region string) map[ec2types.InstanceType]float64 { return InitialOnDemandPricesAWS["us-east-1"] } -func NewStaticPricingProvider() nvp.Provider { - region := os.Getenv("AWS_REGION") +func NewStaticPricingProvider(region string) nvp.Provider { if region == "" { - region = "us-east-1" + region := os.Getenv("AWS_REGION") + if region == "" { + region = "us-east-1" + } } return &pricingProvider{ diff --git a/pkg/client/client.go b/pkg/client/client.go index 888bc40..d23da74 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -17,6 +17,7 @@ package client import ( "strings" + "github.com/spf13/pflag" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" @@ -60,9 +61,48 @@ func NewNodeClaims(kubeconfig, context string) (*rest.RESTClient, error) { return rest.RESTClientFor(&config) } -func getConfig(kubeconfig, context string) (*rest.Config, error) { +func GetAWSRegionAndProfile(kubeconfig, context string) (region, profile string) { + config := getClientConfig(kubeconfig, context) + raw, err := config.RawConfig() + if err != nil { + return "", "" + } + + if context == "" { + context = raw.CurrentContext + } + kubeContext := raw.Contexts[context] + if kubeContext == nil { + return "", "" + } + auth := raw.AuthInfos[kubeContext.AuthInfo] + if auth == nil || auth.Exec == nil { + return "", "" + } + + // use a flagset to parse the args from the exec config + // + flagSet := pflag.NewFlagSet("aws", pflag.ContinueOnError) + flagSet.ParseErrorsWhitelist.UnknownFlags = true + regionPtr := flagSet.String("region", "", "") + _ = flagSet.Parse(auth.Exec.Args) + + for _, env := range auth.Exec.Env { + if env.Name == "AWS_PROFILE" { + profile = env.Value + } + } + + return *regionPtr, profile +} + +func getClientConfig(kubeconfig, context string) clientcmd.ClientConfig { // use the current context in kubeconfig return clientcmd.NewNonInteractiveDeferredLoadingClientConfig( &clientcmd.ClientConfigLoadingRules{Precedence: strings.Split(kubeconfig, ":")}, - &clientcmd.ConfigOverrides{CurrentContext: context}).ClientConfig() + &clientcmd.ConfigOverrides{CurrentContext: context}) +} + +func getConfig(kubeconfig, context string) (*rest.Config, error) { + return getClientConfig(kubeconfig, context).ClientConfig() }