Skip to content

Commit 66d2080

Browse files
authored
fix: clear credentials before auth (#406)
1 parent d0cb329 commit 66d2080

File tree

2 files changed

+79
-15
lines changed

2 files changed

+79
-15
lines changed

pkg/local_workflows/auth_workflow.go

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ func authEntryPoint(invocationCtx workflow.InvocationContext, _ []workflow.Data)
6060
logger := invocationCtx.GetEnhancedLogger()
6161
engine := invocationCtx.GetEngine()
6262

63+
config.ClearCache()
64+
6365
httpClient := invocationCtx.GetNetworkAccess().GetUnauthorizedHttpClient()
6466
authenticator := auth.NewOAuth2AuthenticatorWithOpts(
6567
config,
@@ -107,10 +109,13 @@ func entryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logg
107109
logger.Printf("Authentication Type: %s", authType)
108110
analytics.AddExtensionStringValue(authTypeParameter, authType)
109111

110-
if strings.EqualFold(authType, auth.AUTH_TYPE_OAUTH) { // OAUTH flow
111-
logger.Printf("Unset legacy token key %q from config", configuration.AUTHENTICATION_TOKEN)
112-
config.Unset(configuration.AUTHENTICATION_TOKEN)
112+
existingSnykToken := config.GetString(configuration.AUTHENTICATION_TOKEN)
113+
// always attempt to clear existing tokens before triggering auth
114+
logger.Print("Unset existing auth keys")
115+
config.Unset(configuration.AUTHENTICATION_TOKEN)
116+
config.Unset(auth.CONFIG_KEY_OAUTH_TOKEN)
113117

118+
if strings.EqualFold(authType, auth.AUTH_TYPE_OAUTH) { // OAUTH flow
114119
headless := config.GetBool(headlessFlag)
115120
logger.Printf("Headless: %v", headless)
116121

@@ -125,25 +130,19 @@ func entryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logg
125130
}
126131
} else if strings.EqualFold(authType, auth.AUTH_TYPE_PAT) { // PAT flow
127132
engine.GetConfiguration().PersistInStorage(auth.CONFIG_KEY_TOKEN)
128-
129-
oldToken := config.GetString(configuration.AUTHENTICATION_TOKEN)
130133
pat := config.GetString(ConfigurationNewAuthenticationToken)
131134

132-
logger.Print("Unset existing auth keys from config")
133-
config.Unset(auth.CONFIG_KEY_OAUTH_TOKEN)
134-
config.Unset(configuration.AUTHENTICATION_TOKEN)
135-
136135
logger.Print("Validating pat")
137136
whoamiConfig := config.Clone()
138-
// we don't want to use the cache here, so this is a workaround
139137
whoamiConfig.ClearCache()
138+
// we don't want to use the cache here, so this is a workaround
140139
whoamiConfig.Set(configuration.FLAG_EXPERIMENTAL, true)
141140
whoamiConfig.Set(configuration.AUTHENTICATION_TOKEN, pat)
142141
_, whoamiErr := engine.InvokeWithConfig(workflow.NewWorkflowIdentifier("whoami"), whoamiConfig)
143142
if whoamiErr != nil {
144143
// reset config file
145-
if len(oldToken) > 0 {
146-
config.Set(auth.CONFIG_KEY_TOKEN, oldToken)
144+
if len(existingSnykToken) > 0 {
145+
config.Set(auth.CONFIG_KEY_TOKEN, existingSnykToken)
147146
}
148147
return whoamiErr
149148
}
@@ -158,9 +157,6 @@ func entryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logg
158157
logger.Debug().Err(err).Msg("Failed to output authenticated message")
159158
}
160159
} else { // LEGACY flow
161-
logger.Printf("Unset oauth key %q from config", auth.CONFIG_KEY_OAUTH_TOKEN)
162-
config.Unset(auth.CONFIG_KEY_OAUTH_TOKEN)
163-
164160
config.Set(configuration.RAW_CMD_ARGS, os.Args[1:])
165161
config.Set(configuration.WORKFLOW_USE_STDIO, true)
166162
config.Set(configuration.AUTHENTICATION_TOKEN, "") // clear token to avoid using it during authentication

pkg/local_workflows/auth_workflow_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,74 @@ func Test_pat(t *testing.T) {
163163
})
164164
}
165165

166+
func Test_clearAllCredentialsBeforeAuth(t *testing.T) {
167+
mockCtl := gomock.NewController(t)
168+
defer mockCtl.Finish()
169+
170+
logContent := &bytes.Buffer{}
171+
logger := zerolog.New(logContent)
172+
analytics := analytics.New()
173+
engine := mocks.NewMockEngine(mockCtl)
174+
authenticator := mocks.NewMockAuthenticator(mockCtl)
175+
176+
testCases := []struct {
177+
name string
178+
authType string
179+
setupMocks func()
180+
}{
181+
{
182+
name: "OAuth flow clears all credentials",
183+
authType: auth.AUTH_TYPE_OAUTH,
184+
setupMocks: func() {
185+
authenticator.EXPECT().Authenticate().Return(nil)
186+
},
187+
},
188+
{
189+
name: "PAT flow clears all credentials",
190+
authType: auth.AUTH_TYPE_PAT,
191+
setupMocks: func() {
192+
engine.EXPECT().GetConfiguration().Return(configuration.NewWithOpts()).AnyTimes()
193+
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any()).Return(nil, nil)
194+
},
195+
},
196+
{
197+
name: "Token flow clears all credentials",
198+
authType: auth.AUTH_TYPE_TOKEN,
199+
setupMocks: func() {
200+
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any()).Return(nil, nil)
201+
},
202+
},
203+
}
204+
205+
for _, tc := range testCases {
206+
t.Run(tc.name, func(t *testing.T) {
207+
config := configuration.NewWithOpts()
208+
config.Set(authTypeParameter, tc.authType)
209+
if tc.authType == auth.AUTH_TYPE_PAT {
210+
config.Set(ConfigurationNewAuthenticationToken, "snyk_uat.12345678.abcdefg-hijklmnop.qrstuvwxyz-123456")
211+
}
212+
213+
// Set existing tokens that should be cleared
214+
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, "existing-oauth-token")
215+
config.Set(configuration.AUTHENTICATION_TOKEN, "existing-auth-token")
216+
217+
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
218+
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
219+
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
220+
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
221+
222+
tc.setupMocks()
223+
224+
err := entryPointDI(mockInvocationContext, &logger, engine, authenticator)
225+
assert.NoError(t, err)
226+
227+
// Verify both tokens are cleared regardless of auth type
228+
assert.Empty(t, config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN), "OAuth token should be cleared for %s flow", tc.authType)
229+
assert.Empty(t, config.GetString(configuration.AUTHENTICATION_TOKEN), "Authentication token should be cleared for %s flow", tc.authType)
230+
})
231+
}
232+
}
233+
166234
func Test_autodetectAuth(t *testing.T) {
167235
t.Run("in stable versions, token by default", func(t *testing.T) {
168236
expected := auth.AUTH_TYPE_OAUTH

0 commit comments

Comments
 (0)