Skip to content

Commit e072657

Browse files
Merge pull request #42 from basedosdados/feat/refresh-tokens
feat: refresh tokens
2 parents d7dd8cd + 542a192 commit e072657

File tree

6 files changed

+283
-119
lines changed

6 files changed

+283
-119
lines changed

frontend/api/api_client.py

Lines changed: 123 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,101 @@
1+
from datetime import datetime, timedelta
12
from typing import Iterator
23
from uuid import UUID, uuid4
34

45
import httpx
6+
import jwt
7+
import streamlit as st
58
from loguru import logger
69

710
from frontend.datatypes import (EventData, MessagePair, StreamEvent, Thread,
811
UserMessage)
12+
from frontend.exceptions import SessionExpiredException
913

1014

1115
class APIClient:
1216
def __init__(self, base_url: str):
1317
self.base_url = base_url
1418
self.logger = logger.bind(classname=self.__class__.__name__)
1519

16-
def authenticate(self, email: str, password: str) -> tuple[str|None, str]:
20+
def _is_token_expired(self, token: str) -> bool:
21+
"""Check if a JWT token is expired or about to expire (within 1 minute).
22+
23+
Args:
24+
token (str): The token.
25+
26+
Returns:
27+
bool: Whether the token is expired or not.
28+
"""
29+
if not token:
30+
return True
31+
32+
try:
33+
payload: dict = jwt.decode(token, options={"verify_signature": False})
34+
exp = payload.get("exp")
35+
if not exp:
36+
return True
37+
expiration = datetime.fromtimestamp(exp)
38+
return datetime.now() >= expiration - timedelta(seconds=60)
39+
except Exception:
40+
return True
41+
42+
def _refresh_access_token(self, refresh_token: str) -> str:
43+
"""Refresh the access token using a refresh token.
44+
45+
Args:
46+
refresh_token (str): The refresh token.
47+
48+
Returns:
49+
str: A refreshed access token.
50+
"""
51+
response = httpx.post(
52+
f"{self.base_url}/chatbot/token/refresh/",
53+
json={"refresh": refresh_token}
54+
)
55+
response.raise_for_status()
56+
access_token = response.json()["access"]
57+
return access_token
58+
59+
def _get_headers(self, access_token: str, refresh_token: str) -> dict[str, str]:
60+
"""Get authorization headers, refreshing access token as needed.
61+
62+
Args:
63+
access_token (str): The access token.
64+
refresh_token (str): The refresh token.
65+
66+
Raises:
67+
SessionExpiredException: If refresh token is expired (401).
68+
69+
Returns:
70+
dict[str, str]: The authorization headers,
71+
"""
72+
if self._is_token_expired(access_token):
73+
self.logger.info("[AUTH] Access token expired, refreshing...")
74+
try:
75+
access_token = self._refresh_access_token(refresh_token)
76+
st.session_state["access_token"] = access_token
77+
self.logger.success("[AUTH] Access token refreshed successfully")
78+
except httpx.HTTPStatusError as e:
79+
if e.response.status_code == httpx.codes.UNAUTHORIZED:
80+
self.logger.info("[AUTH] Refresh token expired")
81+
raise SessionExpiredException from e
82+
raise # Re-raise other HTTP errors
83+
84+
return {"Authorization": f"Bearer {access_token}"}
85+
86+
def authenticate(self, email: str, password: str) -> tuple[str|None, str|None, str]:
1787
"""Send a post request to the authentication endpoint.
1888
1989
Args:
2090
email (str): The email.
2191
password (str): The password.
2292
2393
Returns:
24-
tuple[str|None, str]: A tuple containing the access token and a status message.
94+
tuple[str|None, str|None, str]:
95+
A tuple containing the access token, the refresh token and a status message.
2596
"""
2697
access_token = None
27-
98+
refresh_token = None
2899
message = "Ops! Ocorreu um erro durante o login. Por favor, tente novamente."
29100

30101
try:
@@ -38,28 +109,30 @@ def authenticate(self, email: str, password: str) -> tuple[str|None, str]:
38109
response.raise_for_status()
39110

40111
access_token = response.json().get("access")
112+
refresh_token = response.json().get("refresh")
41113

42-
if access_token:
43-
self.logger.success(f"[LOGIN] Successfully logged in")
114+
if access_token and refresh_token:
115+
self.logger.success(f"[AUTH] Successfully logged in")
44116
message = "Conectado com sucesso!"
45117
else:
46-
self.logger.error(f"[LOGIN] No access token returned")
118+
self.logger.error(f"[AUTH] No access and refresh tokens returned")
47119
except httpx.HTTPStatusError:
48120
if response.status_code == httpx.codes.UNAUTHORIZED:
49-
self.logger.warning(f"[LOGIN] Invalid credentials")
121+
self.logger.warning(f"[AUTH] Invalid credentials")
50122
message = "Usuário ou senha incorretos."
51123
else:
52-
self.logger.exception(f"[LOGIN] HTTP error:")
124+
self.logger.exception(f"[AUTH] HTTP error:")
53125
except Exception:
54-
self.logger.exception(f"[LOGIN] Login error:")
126+
self.logger.exception(f"[AUTH] Login error:")
55127

56-
return access_token, message
128+
return access_token, refresh_token, message
57129

58-
def create_thread(self, access_token: str, title: str) -> Thread|None:
130+
def create_thread(self, access_token: str, refresh_token: str, title: str) -> Thread|None:
59131
"""Create a thread.
60132
61133
Args:
62134
access_token (str): User access token.
135+
refresh_token (str): User refresh token.
63136
title (str): The thread title.
64137
65138
Returns:
@@ -71,21 +144,24 @@ def create_thread(self, access_token: str, title: str) -> Thread|None:
71144
response = httpx.post(
72145
url=f"{self.base_url}/chatbot/threads/",
73146
json={"title": title},
74-
headers={"Authorization": f"Bearer {access_token}"},
147+
headers=self._get_headers(access_token, refresh_token),
75148
)
76149
response.raise_for_status()
77150
thread = Thread(**response.json())
78151
self.logger.success(f"[THREAD] Thread created successfully for user {thread.account}")
79152
return thread
153+
except SessionExpiredException:
154+
raise
80155
except Exception:
81156
self.logger.exception(f"[THREAD] Error on thread creation:")
82157
return None
83158

84-
def get_threads(self, access_token: str) -> list[Thread]|None:
159+
def get_threads(self, access_token: str, refresh_token: str) -> list[Thread]|None:
85160
"""Get all threads from a user.
86161
87162
Args:
88163
access_token (str): User access token.
164+
refresh_token (str): User refresh token.
89165
90166
Returns:
91167
list[Thread]|None: A list of Thread objects if any thread was found. None otherwise.
@@ -95,39 +171,54 @@ def get_threads(self, access_token: str) -> list[Thread]|None:
95171
response = httpx.get(
96172
url=f"{self.base_url}/chatbot/threads/",
97173
params={"order_by": "created_at"},
98-
headers={"Authorization": f"Bearer {access_token}"}
174+
headers=self._get_headers(access_token, refresh_token)
99175
)
100176
response.raise_for_status()
101177
threads = [Thread(**thread) for thread in response.json()]
102178
self.logger.success(f"[THREAD] Threads retrieved successfully")
103179
return threads
180+
except SessionExpiredException:
181+
raise
104182
except Exception:
105183
self.logger.exception(f"[THREAD] Error on threads retrieval:")
106184
return None
107185

108-
def get_message_pairs(self, access_token: str, thread_id: UUID) -> list[MessagePair]|None:
186+
def get_message_pairs(self, access_token: str, refresh_token: str, thread_id: UUID) -> list[MessagePair]|None:
187+
"""Get all messages from a thread.
188+
189+
Args:
190+
access_token (str): User access token.
191+
refresh_token (str): User refresh token.
192+
thread_id (UUID): Thread unique identifier.
193+
194+
Returns:
195+
list[MessagePair]|None: A list of MessagePair objects if any message was found. None otherwise.
196+
"""
109197
self.logger.info(f"[MESSAGE] Retrieving message pairs for thread {thread_id}")
110198
try:
111199
response = httpx.get(
112200
url=f"{self.base_url}/chatbot/threads/{thread_id}/messages/",
113201
params={"order_by": "created_at"},
114-
headers={"Authorization": f"Bearer {access_token}"}
202+
headers=self._get_headers(access_token, refresh_token)
115203
)
116204
response.raise_for_status()
117205
message_pairs = [MessagePair(**pair) for pair in response.json()]
118206
self.logger.success(f"[MESSAGE] Message pairs retrieved successfully for thread {thread_id}")
119207
return message_pairs
208+
except SessionExpiredException:
209+
raise
120210
except Exception:
121211
self.logger.exception(f"[MESSAGE] Error on message pairs retrieval for thread {thread_id}:")
122212
return None
123213

124-
def send_message(self, access_token: str, message: str, thread_id: UUID) -> Iterator[StreamEvent]:
214+
def send_message(self, access_token: str, refresh_token: str, message: str, thread_id: UUID) -> Iterator[StreamEvent]:
125215
"""Send a user message and stream the assistant's response.
126216
127217
Args:
128-
access_token (str): The user's access token.
218+
access_token (str): User access token.
219+
refresh_token (str): User refresh token.
129220
message (str): The message sent by the user.
130-
thread_id (UUID): The unique identifier of the thread.
221+
thread_id (UUID):Thread unique identifier.
131222
132223
Yields:
133224
Iterator[StreamEvent]: Iterator of `StreamEvent` objects.
@@ -143,7 +234,7 @@ def send_message(self, access_token: str, message: str, thread_id: UUID) -> Iter
143234
with httpx.stream(
144235
method="POST",
145236
url=f"{self.base_url}/chatbot/threads/{thread_id}/messages/",
146-
headers={"Authorization": f"Bearer {access_token}"},
237+
headers=self._get_headers(access_token, refresh_token),
147238
json=user_message.model_dump(mode="json"),
148239
timeout=httpx.Timeout(5.0, read=300.0),
149240
) as response:
@@ -161,6 +252,8 @@ def send_message(self, access_token: str, message: str, thread_id: UUID) -> Iter
161252
stream_completed = True
162253

163254
yield event
255+
except SessionExpiredException:
256+
raise
164257
except httpx.ReadTimeout:
165258
self.logger.exception(f"[MESSAGE] Timeout error on sending user message:")
166259
error_message=(
@@ -194,11 +287,12 @@ def send_message(self, access_token: str, message: str, thread_id: UUID) -> Iter
194287
data=EventData(run_id=uuid4())
195288
)
196289

197-
def send_feedback(self, access_token: str, message_pair_id: UUID, rating: int, comments: str) -> bool:
290+
def send_feedback(self, access_token: str, refresh_token: str, message_pair_id: UUID, rating: int, comments: str) -> bool:
198291
"""Send a feedback.
199292
200293
Args:
201294
access_token (str): User access token.
295+
refresh_token (str): User refresh token.
202296
message_pair_id (UUID): The message pair unique identifier.
203297
rating (int): The rating (0 or 1).
204298
comments (str): The comments.
@@ -212,20 +306,23 @@ def send_feedback(self, access_token: str, message_pair_id: UUID, rating: int, c
212306
response = httpx.put(
213307
url=f"{self.base_url}/chatbot/message-pairs/{message_pair_id}/feedbacks/",
214308
json={"rating": rating, "comment": comments},
215-
headers={"Authorization": f"Bearer {access_token}"}
309+
headers=self._get_headers(access_token, refresh_token)
216310
)
217311
response.raise_for_status()
218312
self.logger.success(f"[FEEDBACK] Feedback sent successfully")
219313
return True
314+
except SessionExpiredException:
315+
raise
220316
except Exception:
221317
self.logger.exception(f"[FEEDBACK] Error on sending feedback:")
222318
return False
223319

224-
def delete_thread(self, access_token: str, thread_id: UUID) -> bool:
320+
def delete_thread(self, access_token: str, refresh_token: str, thread_id: UUID) -> bool:
225321
"""Soft delete a thread and hard delete all its checkpoints.
226322
227323
Args:
228324
access_token (str): User access token.
325+
refresh_token (str): User refresh token.
229326
thread_id (UUID): Thread unique identifier.
230327
231328
Returns:
@@ -236,12 +333,14 @@ def delete_thread(self, access_token: str, thread_id: UUID) -> bool:
236333
try:
237334
response = httpx.delete(
238335
url=f"{self.base_url}/chatbot/threads/{thread_id}/",
239-
headers={"Authorization": f"Bearer {access_token}"},
336+
headers=self._get_headers(access_token, refresh_token),
240337
timeout=httpx.Timeout(5.0, read=60.0)
241338
)
242339
response.raise_for_status()
243340
self.logger.success(f"[CLEAR] Assistant memory cleared successfully")
244341
return True
342+
except SessionExpiredException:
343+
raise
245344
except Exception:
246345
self.logger.exception("[CLEAR] Error on clearing assistant memory:")
247346
return False

0 commit comments

Comments
 (0)