diff --git a/cmd/atlas/atlas.go b/cmd/atlas/atlas.go index 5b26d40382..e88e51e4bc 100644 --- a/cmd/atlas/atlas.go +++ b/cmd/atlas/atlas.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/AlecAivazis/survey/v2/core" + "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/root" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/telemetry" @@ -36,6 +37,8 @@ func execute(rootCmd *cobra.Command) { To learn more, see our documentation: https://www.mongodb.com/docs/atlas/cli/stable/connect-atlas-cli/` if cmd, err := rootCmd.ExecuteContextC(ctx); err != nil { + err := commonerrors.Check(err) + rootCmd.PrintErrln(rootCmd.ErrPrefix(), err) if !telemetry.StartedTrackingCommand() { telemetry.StartTrackingCommand(cmd, os.Args[1:]) } diff --git a/internal/cli/auth/login.go b/internal/cli/auth/login.go index ae53d97450..a2fd9a9601 100644 --- a/internal/cli/auth/login.go +++ b/internal/cli/auth/login.go @@ -22,6 +22,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" + "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -397,7 +398,7 @@ func (opts *LoginOpts) LoginPreRun(ctx context.Context) func() error { // clean up any expired or invalid tokens opts.config.Set(config.AccessTokenField, "") - if !errors.Is(err, cli.ErrInvalidRefreshToken) { + if !commonerrors.IsInvalidRefreshToken(err) { return err } } diff --git a/internal/cli/backup/restores/start.go b/internal/cli/backup/restores/start.go index 2875289285..6a7ddb85bc 100644 --- a/internal/cli/backup/restores/start.go +++ b/internal/cli/backup/restores/start.go @@ -73,7 +73,7 @@ func (opts *StartOpts) Run() error { if opts.isFlexCluster { r, err := opts.store.CreateRestoreFlexClusterJobs(opts.ConfigProjectID(), opts.clusterName, opts.newFlexBackupRestoreJobCreate()) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) } @@ -81,7 +81,7 @@ func (opts *StartOpts) Run() error { request := opts.newCloudProviderSnapshotRestoreJob() restoreJob, err := opts.store.CreateRestoreJobs(opts.ConfigProjectID(), opts.clusterName, request) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(restoreJob) diff --git a/internal/cli/backup/snapshots/create.go b/internal/cli/backup/snapshots/create.go index 0221066390..52d5041f34 100644 --- a/internal/cli/backup/snapshots/create.go +++ b/internal/cli/backup/snapshots/create.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -59,7 +58,7 @@ func (opts *CreateOpts) Run() error { r, err := opts.store.CreateSnapshot(opts.ConfigProjectID(), opts.clusterName, createRequest) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) } diff --git a/internal/cli/clusters/advancedsettings/update.go b/internal/cli/clusters/advancedsettings/update.go index e542d81c8e..efb5a91ae6 100644 --- a/internal/cli/clusters/advancedsettings/update.go +++ b/internal/cli/clusters/advancedsettings/update.go @@ -18,7 +18,6 @@ import ( "context" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -67,7 +66,7 @@ func (opts *UpdateOpts) initStore(ctx context.Context) func() error { func (opts *UpdateOpts) Run() error { r, err := opts.store.UpdateAtlasClusterConfigurationOptions(opts.ConfigProjectID(), opts.name, opts.newProcessArgs()) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) diff --git a/internal/cli/clusters/describe.go b/internal/cli/clusters/describe.go index 8777e2b585..5d6d6dae37 100644 --- a/internal/cli/clusters/describe.go +++ b/internal/cli/clusters/describe.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -83,12 +82,12 @@ func (opts *DescribeOpts) RunFlexCluster(err error) error { } if *apiError.ErrorCode != cannotUseFlexWithClusterApisErrorCode { - return commonerrors.Check(err) + return err } r, err := opts.store.FlexCluster(opts.ConfigProjectID(), opts.name) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) diff --git a/internal/cli/clusters/pause.go b/internal/cli/clusters/pause.go index 91027897a3..144548fd56 100644 --- a/internal/cli/clusters/pause.go +++ b/internal/cli/clusters/pause.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -60,14 +59,14 @@ func (opts *PauseOpts) Run() error { if isIndependentShardScaling(opts.autoScalingMode) { r, err := opts.store.PauseClusterLatest(opts.ConfigProjectID(), opts.name) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) } r, err := opts.store.PauseCluster(opts.ConfigProjectID(), opts.name) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) } diff --git a/internal/cli/clusters/start.go b/internal/cli/clusters/start.go index 5abaca7abf..68c7aebc0f 100644 --- a/internal/cli/clusters/start.go +++ b/internal/cli/clusters/start.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -60,14 +59,14 @@ func (opts *StartOpts) Run() error { if isIndependentShardScaling(opts.autoScalingMode) { r, err := opts.store.StartClusterLatest(opts.ConfigProjectID(), opts.name) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) } r, err := opts.store.StartCluster(opts.ConfigProjectID(), opts.name) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) } diff --git a/internal/cli/clusters/update.go b/internal/cli/clusters/update.go index 0fe0d62ce3..bc92d5b4c1 100644 --- a/internal/cli/clusters/update.go +++ b/internal/cli/clusters/update.go @@ -211,7 +211,7 @@ func (opts *UpdateOpts) RunDedicatedClusterWideScaling() error { r, err := opts.store.UpdateCluster(opts.ConfigProjectID(), opts.name, cluster) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) diff --git a/internal/cli/commonerrors/errors.go b/internal/cli/commonerrors/errors.go index aacb917402..a4fa11c127 100644 --- a/internal/cli/commonerrors/errors.go +++ b/internal/cli/commonerrors/errors.go @@ -16,41 +16,113 @@ package commonerrors import ( "errors" + "net/http" - "go.mongodb.org/atlas-sdk/v20250312005/admin" + atlasClustersPinned "go.mongodb.org/atlas-sdk/v20240530005/admin" + atlasv2 "go.mongodb.org/atlas-sdk/v20250312005/admin" + atlas "go.mongodb.org/atlas/mongodbatlas" ) var ( errClusterUnsupported = errors.New("atlas supports this command only for M10+ clusters. You can upgrade your cluster by running the 'atlas cluster upgrade' command") errOutsideVPN = errors.New("forbidden action outside access allow list, if you are a MongoDB employee double check your VPN connection") errAsymmetricShardUnsupported = errors.New("trying to run a cluster wide scaling command on an independent shard scaling cluster. Use --autoScalingMode 'independentShardScaling' instead") + ErrUnauthorized = errors.New(`this action requires authentication + +To log in using your Atlas username and password, run: atlas auth login +To set credentials using API keys, run: atlas config init`) + ErrInvalidRefreshToken = errors.New(`session expired + +Please note that your session expires periodically. +If you use Atlas CLI for automation, see https://www.mongodb.com/docs/atlas/cli/stable/atlas-cli-automate/ for best practices. +To login, run: atlas auth login`) ) const ( - asymmetricShardUnsupportedErrorCode = "ASYMMETRIC_SHARD_UNSUPPORTED" + unknownErrorCode = "UNKNOWN_ERROR" + asymmetricShardUnsupportedErrorCode = "ASYMMETRIC_SHARD_UNSUPPORTED" + tenantClusterUpdateUnsupportedErrorCode = "TENANT_CLUSTER_UPDATE_UNSUPPORTED" + globalUserOutsideSubnetErrorCode = "GLOBAL_USER_OUTSIDE_SUBNET" + unauthorizedErrorCode = "UNAUTHORIZED" + invalidRefreshTokenErrorCode = "INVALID_REFRESH_TOKEN" ) +// Check checks the error and returns a more user-friendly error message if applicable. func Check(err error) error { if err == nil { return nil } - apiError, ok := admin.AsError(err) - if ok { - switch apiError.GetErrorCode() { - case "TENANT_CLUSTER_UPDATE_UNSUPPORTED": - return errClusterUnsupported - case "GLOBAL_USER_OUTSIDE_SUBNET": - return errOutsideVPN - case asymmetricShardUnsupportedErrorCode: - return errAsymmetricShardUnsupported - } + apiErrorCode := getErrorCode(err) + + switch apiErrorCode { + case unauthorizedErrorCode: + return ErrUnauthorized + case invalidRefreshTokenErrorCode: + return ErrInvalidRefreshToken + case tenantClusterUpdateUnsupportedErrorCode: + return errClusterUnsupported + case globalUserOutsideSubnetErrorCode: + return errOutsideVPN + case asymmetricShardUnsupportedErrorCode: + return errAsymmetricShardUnsupported } + + apiError := getError(err) // some `Unauthorized` errors do not have an error code, so we check the HTTP status code + + if apiError == http.StatusUnauthorized { + return ErrUnauthorized + } + return err } +// getErrorCode extracts the error code from the error if it is an Atlas error. +// This function checks for v2 SDK, the pinned clusters SDK and the old SDK errors. +// If the error is not any of these Atlas errors, it returns "UNKNOWN_ERROR". +func getErrorCode(err error) string { + if err == nil { + return unknownErrorCode + } + + var atlasErr *atlas.ErrorResponse + if errors.As(err, &atlasErr) { + return atlasErr.ErrorCode + } + if sdkError, ok := atlasv2.AsError(err); ok { + return sdkError.ErrorCode + } + if sdkPinnedError, ok := atlasClustersPinned.AsError(err); ok { + return sdkPinnedError.GetErrorCode() + } + + return unknownErrorCode +} + +// getError extracts the HTTP error code from the error if it is an Atlas error. +// This function checks for v2 SDK, the pinned clusters SDK and the old SDK errors. +// If the error is not any of these Atlas errors, it returns 0. +func getError(err error) int { + if err == nil { + return 0 + } + + var atlasErr *atlas.ErrorResponse + if errors.As(err, &atlasErr) { + return atlasErr.HTTPCode + } + if apiError, ok := atlasv2.AsError(err); ok { + return apiError.GetError() + } + if apiPinnedError, ok := atlasClustersPinned.AsError(err); ok { + return apiPinnedError.GetError() + } + + return 0 +} + func IsAsymmetricShardUnsupported(err error) bool { - apiError, ok := admin.AsError(err) + apiError, ok := atlasv2.AsError(err) if !ok { return false } @@ -58,9 +130,14 @@ func IsAsymmetricShardUnsupported(err error) bool { } func IsCannotUseFlexWithClusterApis(err error) bool { - apiError, ok := admin.AsError(err) + apiError, ok := atlasv2.AsError(err) if !ok { return false } return apiError.GetErrorCode() == "CANNOT_USE_FLEX_CLUSTER_IN_CLUSTER_API" } + +func IsInvalidRefreshToken(err error) bool { + errCode := getErrorCode(err) + return errCode == invalidRefreshTokenErrorCode +} diff --git a/internal/cli/commonerrors/errors_test.go b/internal/cli/commonerrors/errors_test.go index 8e7c56da20..cb7b1c1095 100644 --- a/internal/cli/commonerrors/errors_test.go +++ b/internal/cli/commonerrors/errors_test.go @@ -20,17 +20,23 @@ import ( "errors" "testing" - "go.mongodb.org/atlas-sdk/v20250312005/admin" + atlasClustersPinned "go.mongodb.org/atlas-sdk/v20240530005/admin" + atlasv2 "go.mongodb.org/atlas-sdk/v20250312005/admin" + atlas "go.mongodb.org/atlas/mongodbatlas" ) func TestCheck(t *testing.T) { dummyErr := errors.New("dummy error") - skderr := &admin.GenericOpenAPIError{} - skderr.SetModel(admin.ApiError{ErrorCode: "TENANT_CLUSTER_UPDATE_UNSUPPORTED"}) + skderr := &atlasv2.GenericOpenAPIError{} + skderr.SetModel(atlasv2.ApiError{ErrorCode: tenantClusterUpdateUnsupportedErrorCode}) - asymmetricShardErr := &admin.GenericOpenAPIError{} - asymmetricShardErr.SetModel(admin.ApiError{ErrorCode: asymmetricShardUnsupportedErrorCode}) + asymmetricShardErr := &atlasv2.GenericOpenAPIError{} + asymmetricShardErr.SetModel(atlasv2.ApiError{ErrorCode: asymmetricShardUnsupportedErrorCode}) + + unauthErr := &atlas.ErrorResponse{ErrorCode: unauthorizedErrorCode} + + invalidRefreshTokenErr := &atlas.ErrorResponse{ErrorCode: invalidRefreshTokenErrorCode} testCases := []struct { name string @@ -57,6 +63,16 @@ func TestCheck(t *testing.T) { err: asymmetricShardErr, want: errAsymmetricShardUnsupported, }, + { + name: "unauthorized error", + err: unauthErr, + want: ErrUnauthorized, + }, + { + name: "invalid refresh token error", + err: invalidRefreshTokenErr, + want: ErrInvalidRefreshToken, + }, } for _, tc := range testCases { @@ -67,3 +83,98 @@ func TestCheck(t *testing.T) { }) } } + +func TestGetError(t *testing.T) { + dummyErr := errors.New("dummy error") + + unauthorizedCode := 401 + forbiddenCode := 403 + notFoundCode := 404 + + atlasErr := &atlas.ErrorResponse{HTTPCode: unauthorizedCode} + atlasv2Err := &atlasv2.GenericOpenAPIError{} + atlasv2Err.SetModel(atlasv2.ApiError{Error: forbiddenCode}) + atlasClustersPinnedErr := &atlasClustersPinned.GenericOpenAPIError{} + atlasClustersPinnedErr.SetModel(atlasClustersPinned.ApiError{Error: ¬FoundCode}) + + testCases := []struct { + name string + err error + want int + }{ + { + name: "atlas unauthorized error", + err: atlasErr, + want: unauthorizedCode, + }, + { + name: "atlasv2 forbidden error", + err: atlasv2Err, + want: forbiddenCode, + }, + { + name: "atlasClusterPinned not found error", + err: atlasClustersPinnedErr, + want: notFoundCode, + }, + { + name: "arbitrary error", + err: dummyErr, + want: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := getError(tc.err); got != tc.want { + t.Errorf("GetError(%v) = %v, want %v", tc.err, got, tc.want) + } + }) + } +} + +func TestGetErrorCode(t *testing.T) { + dummyErr := errors.New("dummy error") + + atlasErr := &atlas.ErrorResponse{ErrorCode: invalidRefreshTokenErrorCode} + atlasv2Err := &atlasv2.GenericOpenAPIError{} + atlasv2Err.SetModel(atlasv2.ApiError{ErrorCode: tenantClusterUpdateUnsupportedErrorCode}) + atlasClustersPinnedErr := &atlasClustersPinned.GenericOpenAPIError{} + asymmetricCode := asymmetricShardUnsupportedErrorCode + atlasClustersPinnedErr.SetModel(atlasClustersPinned.ApiError{ErrorCode: &asymmetricCode}) + + testCases := []struct { + name string + err error + want string + }{ + { + name: "atlas error", + err: atlasErr, + want: invalidRefreshTokenErrorCode, + }, + { + name: "atlasv2 error", + err: atlasv2Err, + want: tenantClusterUpdateUnsupportedErrorCode, + }, + { + name: "atlasClusterPinned error", + err: atlasClustersPinnedErr, + want: asymmetricShardUnsupportedErrorCode, + }, + { + name: "arbitrary error", + err: dummyErr, + want: unknownErrorCode, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := getErrorCode(tc.err); got != tc.want { + t.Errorf("GetErrorCode(%v) = %v, want %v", tc.err, got, tc.want) + } + }) + } +} diff --git a/internal/cli/default_setter_opts.go b/internal/cli/default_setter_opts.go index bcf775e546..0d68f65dd5 100644 --- a/internal/cli/default_setter_opts.go +++ b/internal/cli/default_setter_opts.go @@ -22,7 +22,6 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/briandowns/spinner" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/pointer" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/prompt" @@ -92,7 +91,6 @@ func (opts *DefaultSetterOpts) projects() (ids, names []string, err error) { projects, err = opts.Store.GetOrgProjects(opts.OrgID, list) } if err != nil { - err = commonerrors.Check(err) if atlasErr, ok := atlasv2.AsError(err); ok && atlasErr.GetError() == 404 { return nil, nil, errNoResults } @@ -120,7 +118,7 @@ func (opts *DefaultSetterOpts) orgs(filter string) (results []atlasv2.AtlasOrgan if atlasErr, ok := atlasv2.AsError(err); ok && atlasErr.GetError() == 404 { return nil, errNoResults } - return nil, commonerrors.Check(err) + return nil, err } if orgs == nil { return nil, errNoResults diff --git a/internal/cli/refresher_opts.go b/internal/cli/refresher_opts.go index 8b3bbebac1..08f1091030 100644 --- a/internal/cli/refresher_opts.go +++ b/internal/cli/refresher_opts.go @@ -16,8 +16,6 @@ package cli import ( "context" - "errors" - "fmt" "net/http" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" @@ -56,8 +54,6 @@ func (opts *RefresherOpts) WithFlow(f Refresher) { opts.flow = f } -var ErrInvalidRefreshToken = errors.New("session expired") - func (opts *RefresherOpts) RefreshAccessToken(ctx context.Context) error { current, err := config.Token() if current == nil { @@ -69,16 +65,6 @@ func (opts *RefresherOpts) RefreshAccessToken(ctx context.Context) error { } t, _, err := opts.flow.RefreshToken(ctx, config.RefreshToken()) if err != nil { - var target *atlas.ErrorResponse - if errors.As(err, &target) && target.ErrorCode == "INVALID_REFRESH_TOKEN" { - return fmt.Errorf( - `%w - -Please note that your session expires periodically. -If you use Atlas CLI for automation, see https://www.mongodb.com/docs/atlas/cli/stable/atlas-cli-automate/ for best practices. -To login, run: atlas auth login`, - ErrInvalidRefreshToken) - } return err } config.SetAccessToken(t.AccessToken) diff --git a/internal/cli/root/builder.go b/internal/cli/root/builder.go index 0899064f0b..9d533339f9 100644 --- a/internal/cli/root/builder.go +++ b/internal/cli/root/builder.go @@ -113,7 +113,8 @@ Use the --help flag with any command for more info on that command.`, Example: ` # Display the help menu for the config command: atlas config --help `, - SilenceUsage: true, + SilenceUsage: true, + SilenceErrors: true, Annotations: map[string]string{ "toc": "true", }, diff --git a/internal/cli/serverless/backup/restores/create.go b/internal/cli/serverless/backup/restores/create.go index 6e2e24d684..fe952c4a7a 100644 --- a/internal/cli/serverless/backup/restores/create.go +++ b/internal/cli/serverless/backup/restores/create.go @@ -20,7 +20,6 @@ import ( "fmt" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/flag" @@ -73,7 +72,7 @@ func (opts *CreateOpts) Run() error { r, err := opts.store.ServerlessCreateRestoreJobs(opts.ConfigProjectID(), opts.clusterName, request) if err != nil { - return commonerrors.Check(err) + return err } return opts.Print(r) diff --git a/internal/cli/setup/setup_cmd.go b/internal/cli/setup/setup_cmd.go index 911013408d..0ffe0ce364 100644 --- a/internal/cli/setup/setup_cmd.go +++ b/internal/cli/setup/setup_cmd.go @@ -28,6 +28,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/auth" + "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/compass" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" @@ -595,7 +596,7 @@ func (opts *Opts) PreRun(ctx context.Context) error { // The error is useful in other components that call `validate.NoAPIKeys()` return nil } - if err := opts.register.RefreshAccessToken(ctx); err != nil && errors.Is(err, cli.ErrInvalidRefreshToken) { + if err := opts.register.RefreshAccessToken(ctx); err != nil && commonerrors.IsInvalidRefreshToken(err) { opts.skipLogin = false return nil } diff --git a/internal/store/store.go b/internal/store/store.go index 1dcd93629c..7429c45c89 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -71,8 +71,7 @@ func (s *Store) httpClient(httpTransport http.RoundTripper) (*http.Client, error return &http.Client{Transport: tr}, nil default: - tr := &transport.AuthRequiredRoundTripper{Base: httpTransport} - return &http.Client{Transport: tr}, nil + return &http.Client{Transport: httpTransport}, nil } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 7f5af61edf..60c62da7c6 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -15,7 +15,6 @@ package transport import ( - "fmt" "net" "net/http" "time" @@ -23,7 +22,6 @@ import ( "github.com/mongodb-forks/digest" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/oauth" - "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/validate" atlasauth "go.mongodb.org/atlas/auth" ) @@ -112,24 +110,3 @@ func (tr *tokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { return tr.base.RoundTrip(req) } - -type AuthRequiredRoundTripper struct { - Base http.RoundTripper -} - -func (tr *AuthRequiredRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := tr.Base.RoundTrip(req) - if resp != nil && resp.StatusCode == http.StatusUnauthorized { - return nil, fmt.Errorf( - `%w - -To log in using your Atlas username and password, run: atlas auth login -To set credentials using API keys, run: atlas config init`, - validate.ErrMissingCredentials, - ) - } - if err != nil { - return nil, err - } - return resp, nil -} diff --git a/internal/validate/validate.go b/internal/validate/validate.go index 514a4bcdd5..df23741f78 100644 --- a/internal/validate/validate.go +++ b/internal/validate/validate.go @@ -24,6 +24,7 @@ import ( "slices" "strings" + "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/commonerrors" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" ) @@ -95,8 +96,6 @@ func ObjectID(s string) error { return nil } -var ErrMissingCredentials = errors.New("this action requires authentication") - // Credentials validates public and private API keys have been set. func Credentials() error { if t, err := config.Token(); t != nil { @@ -106,13 +105,7 @@ func Credentials() error { return nil } - return fmt.Errorf( - `%w - -To log in using your Atlas username and password, run: atlas auth login -To set credentials using API keys, run: atlas config init`, - ErrMissingCredentials, - ) + return commonerrors.ErrUnauthorized } var ErrAlreadyAuthenticatedAPIKeys = errors.New("already authenticated with an API key") diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go index a52784f0af..7c8f896ae0 100644 --- a/internal/validate/validate_test.go +++ b/internal/validate/validate_test.go @@ -169,7 +169,7 @@ func TestCredentials(t *testing.T) { }) } -func TestNoAPIKeyss(t *testing.T) { +func TestNoAPIKeys(t *testing.T) { t.Run("no credentials", func(t *testing.T) { if err := NoAPIKeys(); err != nil { t.Fatalf("NoAPIKeys() unexpected error %v\n", err) diff --git a/test/e2e/setupfailure/setup_failure_test.go b/test/e2e/setupfailure/setup_failure_test.go index 3c7c228bfc..b16b325108 100644 --- a/test/e2e/setupfailure/setup_failure_test.go +++ b/test/e2e/setupfailure/setup_failure_test.go @@ -47,7 +47,7 @@ func TestSetupFailureFlow(t *testing.T) { cmd.Env = os.Environ() resp, err := cmd.CombinedOutput() req.Error(err) - assert.Contains(t, string(resp), "Unauthorized", "Expected unauthorized error due to invalid public key.") + assert.Contains(t, string(resp), "this action requires authentication") }) g.Run("Invalid Private Key", func(t *testing.T) { //nolint:thelper // g.Run replaces t.Run @@ -60,7 +60,7 @@ func TestSetupFailureFlow(t *testing.T) { cmd.Env = os.Environ() resp, err := cmd.CombinedOutput() req.Error(err) - assert.Contains(t, string(resp), "Unauthorized", "Expected unauthorized error due to invalid private key.") + assert.Contains(t, string(resp), "this action requires authentication") }) g.Run("Invalid Project ID", func(t *testing.T) { //nolint:thelper // g.Run replaces t.Run