diff --git a/webauthn/login.go b/webauthn/login.go index 89ff5f87..d0453d12 100644 --- a/webauthn/login.go +++ b/webauthn/login.go @@ -90,7 +90,7 @@ func (webauthn *WebAuthn) beginLogin(userID []byte, allowedCredentials []protoco } session = &SessionData{ - Challenge: challenge.String(), + Challenge: assertion.Response.Challenge.String(), RelyingPartyID: assertion.Response.RelyingPartyID, UserID: userID, AllowedCredentialIDs: assertion.Response.GetAllowedCredentialIDs(), diff --git a/webauthn/login_test.go b/webauthn/login_test.go index a869cf9e..18e8cddf 100644 --- a/webauthn/login_test.go +++ b/webauthn/login_test.go @@ -108,3 +108,70 @@ func TestWithLoginRelyingPartyID(t *testing.T) { }) } } + +func TestCustomerChallenge(t *testing.T) { + customerChallenge := "hello world" + nonCustomerChallenge := make(protocol.URLEncodedBase64, 0) + testCases := []struct { + name string + have *Config + opts []LoginOption + expectedChallenge func() protocol.URLEncodedBase64 + err string + }{ + { + name: "NonCustomerChallenge", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Non-Customer Challenge", + RPOrigins: []string{"https://example.com"}, + }, + opts: []LoginOption{ + func(opt *protocol.PublicKeyCredentialRequestOptions) { + nonCustomerChallenge = opt.Challenge + }, + }, + expectedChallenge: func() protocol.URLEncodedBase64 { + return nonCustomerChallenge + }, + }, + { + name: "CustomerChallenge", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Customer Challenge", + RPOrigins: []string{"https://example.com"}, + }, + opts: []LoginOption{ + func(opt *protocol.PublicKeyCredentialRequestOptions) { + opt.Challenge = protocol.URLEncodedBase64(customerChallenge) + }, + }, + expectedChallenge: func() protocol.URLEncodedBase64 { + return []byte(customerChallenge) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w, err := New(tc.have) + assert.NoError(t, err) + + user := &defaultUser{ + credentials: []Credential{ + {}, + }, + } + + creation, _, err := w.BeginLogin(user, tc.opts...) + if tc.err != "" { + assert.EqualError(t, err, tc.err) + } else { + assert.NoError(t, err) + require.NotNil(t, creation) + assert.Equal(t, tc.expectedChallenge(), creation.Response.Challenge) + } + }) + } +}