@@ -103,12 +103,12 @@ func pathLogin(b *backend) []*framework.Path {
103103 }
104104}
105105
106- func (b * backend ) getJWTFromOauthPasswordGrant (ctx context.Context , req * logical.Request , username , password string ) (* TokenResponse , error ) {
106+ func (b * backend ) makeOauthRequest (ctx context.Context , req * logical.Request , data url. Values ) (* TokenResponse , error ) {
107107 config , err := b .Config (ctx , req .Storage )
108108 if err != nil {
109109 return nil , err
110110 }
111- if config .OauthClientID == "" || config .OauthEndpoint == "" || config .OauthClientSecret == "" || config .OauthCACert == "" {
111+ if config .OauthEndpoint == "" || config . OauthClientID == "" || config .OauthResource == "" || config .OauthClientSecret == "" || config .OauthCACert == "" {
112112 return nil , fmt .Errorf ("missing configuration elements" )
113113 }
114114
@@ -126,18 +126,13 @@ func (b *backend) getJWTFromOauthPasswordGrant(ctx context.Context, req *logical
126126 },
127127 }
128128
129- data := url.Values {}
130- data .Add ("grant_type" , "password" )
131129 data .Add ("client_id" , config .OauthClientID )
132130 data .Add ("client_secret" , config .OauthClientSecret )
133131 if config .OauthResource != "" {
134132 data .Add ("resource" , config .OauthResource )
135133 }
136- if config .ADDomain != "" {
137- data .Add ("username" , fmt .Sprintf ("%s\\ %s" , config .ADDomain , username ))
138- }
139- data .Add ("password" , password )
140134 encoded := []byte (data .Encode ())
135+
141136 request , err := http .NewRequest ("POST" , config .OauthEndpoint , bytes .NewBuffer (encoded ))
142137 if err != nil {
143138 return nil , err
@@ -163,50 +158,31 @@ func (b *backend) getJWTFromOauthPasswordGrant(ctx context.Context, req *logical
163158 }
164159
165160 return & payload , nil
161+
166162}
167163
168- func (b * backend ) getJWTFromOauthRefresh (ctx context.Context , req * logical.Request , refreshToken string ) (* TokenResponse , error ) {
164+ func (b * backend ) getJWTFromOauthPasswordGrant (ctx context.Context , req * logical.Request , username , password string ) (* TokenResponse , error ) {
169165 config , err := b .Config (ctx , req .Storage )
170166 if err != nil {
171167 return nil , err
172168 }
173- if config .OauthClientID == "" || config .OauthEndpoint == "" || config .OauthClientSecret == "" {
174- return nil , fmt .Errorf ("missing configuration elements" )
175- }
176- caCertPool := x509 .NewCertPool ()
177- httpClient := & pester.Client {
178- Transport : & http.Transport {
179- TLSClientConfig : & tls.Config {
180- RootCAs : caCertPool ,
181- },
182- },
169+
170+ data := url.Values {}
171+ data .Add ("grant_type" , "password" )
172+ if config .ADDomain != "" {
173+ data .Add ("username" , fmt .Sprintf ("%s\\ %s" , config .ADDomain , username ))
174+ } else {
175+ data .Add ("username" , username )
183176 }
177+ data .Add ("password" , password )
178+ return b .makeOauthRequest (ctx , req , data )
179+ }
180+
181+ func (b * backend ) getJWTFromOauthRefresh (ctx context.Context , req * logical.Request , refreshToken string ) (* TokenResponse , error ) {
184182 data := url.Values {}
185183 data .Add ("grant_type" , "refresh_token" )
186- data .Add ("client_id" , config .OauthClientID )
187- data .Add ("client_secret" , config .OauthClientSecret )
188184 data .Add ("refresh_token" , refreshToken )
189- resp , err := httpClient .PostForm (config .OauthEndpoint , data )
190- if err != nil {
191- return nil , err
192- }
193- if resp == nil {
194- return nil , fmt .Errorf ("no response from Oauth server" )
195- }
196- if resp .StatusCode != http .StatusOK {
197- return nil , fmt .Errorf ("refresh token is not valid - likely expired" )
198- }
199- var htmlData []byte
200- htmlData , err = ioutil .ReadAll (resp .Body )
201- if err != nil {
202- return nil , err
203- }
204- var payload TokenResponse
205- err = json .Unmarshal (htmlData , & payload )
206- if err != nil {
207- return nil , err
208- }
209- return & payload , nil
185+ return b .makeOauthRequest (ctx , req , data )
210186}
211187
212188func (b * backend ) pathLoginAliasLookahead (ctx context.Context , req * logical.Request , data * framework.FieldData ) (* logical.Response , error ) {
@@ -223,11 +199,18 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
223199 } else {
224200 jwtMappings = jwtMappingsResp
225201 }
202+ if jwtMappings == nil {
203+ return nil , fmt .Errorf ("unable to map claims" )
204+ }
205+ subject , ok := jwtMappings .Claims [config .SubjectClaim ]
206+ if ! ok {
207+ return nil , fmt .Errorf ("unable to find subject" )
208+ }
226209
227210 return & logical.Response {
228211 Auth : & logical.Auth {
229212 Alias : & logical.Alias {
230- Name : jwtMappings . Claims [ config . SubjectClaim ] .(string ),
213+ Name : subject .(string ),
231214 },
232215 },
233216 }, nil
@@ -263,18 +246,28 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
263246 if err != nil {
264247 return nil , err
265248 }
266- subject := jwtMappings .Claims [config .SubjectClaim ].(string )
267- claims := jwtMappings .Claims [config .RoleClaim ]
249+ if jwtMappings == nil {
250+ return nil , fmt .Errorf ("unable to map claims" )
251+ }
252+ subject , ok := jwtMappings .Claims [config .SubjectClaim ]
253+ if ! ok {
254+ return nil , fmt .Errorf ("unable to find subject" )
255+ }
256+
257+ claims , ok := jwtMappings .Claims [config .RoleClaim ]
258+ if ! ok {
259+ return nil , fmt .Errorf ("unable to find roles" )
260+ }
268261 resp := & logical.Response {
269262 Auth : & logical.Auth {
270263 InternalData : map [string ]interface {}{
271264 "token" : token ,
272265 "refresh_token" : refreshToken ,
273266 },
274- DisplayName : subject ,
267+ DisplayName : subject .( string ) ,
275268 Policies : jwtMappings .Policies ,
276269 Metadata : map [string ]string {
277- "username" : subject ,
270+ "username" : subject .( string ) ,
278271 "jwt" : token ,
279272 "roles" : fmt .Sprintf ("%v" , claims ),
280273 },
@@ -283,7 +276,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
283276 Renewable : true ,
284277 },
285278 Alias : & logical.Alias {
286- Name : jwtMappings . Claims [ config . SubjectClaim ] .(string ),
279+ Name : subject .(string ),
287280 },
288281 },
289282 }
@@ -305,14 +298,15 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
305298 refreshToken , ok := req .Auth .InternalData ["refresh_token" ]
306299 if ok && refreshToken != "" {
307300 tokenResponse , err := b .getJWTFromOauthRefresh (ctx , req , refreshToken .(string ))
301+
308302 if err != nil {
309303 tokenRaw , ok := req .Auth .InternalData ["token" ]
310304 if ! ok {
311305 return nil , fmt .Errorf ("token created in previous version of Vault cannot be validated properly at renewal time" )
312306 }
313307 token = tokenRaw .(string )
314308 } else {
315- token = tokenResponse .IDToken
309+ token = tokenResponse .AccessToken
316310 }
317311 }
318312
@@ -339,11 +333,26 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
339333 }
340334
341335 resp .Auth .GroupAliases = nil
342- roles := jwtMappings .Claims [config .RoleClaim ].([]string )
343- for _ , role := range roles {
336+ if jwtMappings == nil {
337+ return nil , fmt .Errorf ("unable to map claims" )
338+ }
339+
340+ claims , ok := jwtMappings .Claims [config .RoleClaim ].([]interface {})
341+ if ! ok {
342+ roleName , ok := jwtMappings .Claims [config .RoleClaim ].(string )
343+ if ! ok {
344+ return nil , fmt .Errorf ("unable to find roles" )
345+ }
344346 resp .Auth .GroupAliases = append (resp .Auth .GroupAliases , & logical.Alias {
345- Name : role ,
347+ Name : roleName ,
346348 })
349+ } else {
350+ for _ , role := range claims {
351+ roleName := role .(string )
352+ resp .Auth .GroupAliases = append (resp .Auth .GroupAliases , & logical.Alias {
353+ Name : roleName ,
354+ })
355+ }
347356 }
348357
349358 return resp , nil
0 commit comments