@@ -81,11 +81,10 @@ async def as_form(
8181
8282class UserResetPassword (BaseModel ):
8383 email : EmailStr
84- token : str
84+ token : Optional [ str ]
8585 new_password : str
8686 confirm_new_password : str
8787
88- # Use the factory with a different field name
8988 validate_password_strength = create_password_validator ("new_password" )
9089 validate_passwords_match = create_passwords_match_validator (
9190 "new_password" , "confirm_new_password" )
@@ -94,12 +93,16 @@ class UserResetPassword(BaseModel):
9493 async def as_form (
9594 cls ,
9695 email : EmailStr = Form (...),
97- token : str = Form (... ),
96+ token : str = Form (None ),
9897 new_password : str = Form (...),
9998 confirm_new_password : str = Form (...)
10099 ):
101- return cls (email = email , token = token ,
102- new_password = new_password , confirm_new_password = confirm_new_password )
100+ return cls (
101+ email = email ,
102+ token = token ,
103+ new_password = new_password ,
104+ confirm_new_password = confirm_new_password
105+ )
103106
104107
105108# --- DB Request and Response Models ---
@@ -256,8 +259,39 @@ async def forgot_password(
256259@router .post ("/reset_password" )
257260async def reset_password (
258261 user : UserResetPassword = Depends (UserResetPassword .as_form ),
262+ tokens : tuple [Optional [str ], Optional [str ]] = Depends (oauth2_scheme_cookie ),
259263 session : Session = Depends (get_session )
260264):
265+ access_token , _ = tokens
266+
267+ # Handle authenticated user
268+ if access_token :
269+ try :
270+ decoded_token = validate_token (access_token )
271+ if decoded_token and decoded_token .get ("sub" ) == user .email :
272+ # User is authenticated and changing their own password
273+ db_user = session .exec (select (User ).where (
274+ User .email == user .email )).first ()
275+ if not db_user :
276+ raise HTTPException (status_code = 404 , detail = "User not found" )
277+
278+ # Update password
279+ if db_user .password :
280+ db_user .password .hashed_password = get_password_hash (user .new_password )
281+ else :
282+ db_user .password = UserPassword (
283+ hashed_password = get_password_hash (user .new_password )
284+ )
285+ session .commit ()
286+ return RedirectResponse (url = "/settings" , status_code = 303 )
287+
288+ except Exception as e :
289+ logger .error (f"Error validating token: { e } " )
290+
291+ # Handle unauthenticated user with reset token
292+ if not user .token :
293+ raise HTTPException (status_code = 400 , detail = "Reset token required for unauthenticated password reset" )
294+
261295 authorized_user , reset_token = get_user_from_reset_token (
262296 user .email , user .token , session )
263297
@@ -270,16 +304,13 @@ async def reset_password(
270304 user .new_password
271305 )
272306 else :
273- logger .warning (
274- "User password not found during password reset; creating new password for user" )
275307 authorized_user .password = UserPassword (
276308 hashed_password = get_password_hash (user .new_password )
277309 )
278310
279311 reset_token .used = True
280312 session .commit ()
281- session .refresh (authorized_user )
282-
313+
283314 return RedirectResponse (url = "/login" , status_code = 303 )
284315
285316
0 commit comments