1-
1+ #import os
22import boto3
3- from chatfuncs .helper_functions import get_or_create_env_var
4-
5- client_id = get_or_create_env_var ('AWS_CLIENT_ID' , '' ) # This client id is borrowed from async gradio app client
6- print (f'The value of AWS_CLIENT_ID is { client_id } ' )
3+ #import gradio as gr
4+ import hmac
5+ import hashlib
6+ import base64
7+ from chatfuncs .config import AWS_CLIENT_ID , AWS_CLIENT_SECRET , AWS_USER_POOL_ID , AWS_REGION
78
8- user_pool_id = get_or_create_env_var ('AWS_USER_POOL_ID' , '' )
9- print (f'The value of AWS_USER_POOL_ID is { user_pool_id } ' )
9+ def calculate_secret_hash (client_id :str , client_secret :str , username :str ):
10+ message = username + client_id
11+ dig = hmac .new (
12+ str (client_secret ).encode ('utf-8' ),
13+ msg = str (message ).encode ('utf-8' ),
14+ digestmod = hashlib .sha256
15+ ).digest ()
16+ secret_hash = base64 .b64encode (dig ).decode ()
17+ return secret_hash
1018
11- def authenticate_user (username , password , user_pool_id = user_pool_id , client_id = client_id ):
19+ def authenticate_user (username : str , password : str , user_pool_id : str = AWS_USER_POOL_ID , client_id : str = AWS_CLIENT_ID , client_secret : str = AWS_CLIENT_SECRET ):
1220 """Authenticates a user against an AWS Cognito user pool.
1321
1422 Args:
1523 user_pool_id (str): The ID of the Cognito user pool.
1624 client_id (str): The ID of the Cognito user pool client.
1725 username (str): The username of the user.
1826 password (str): The password of the user.
27+ client_secret (str): The client secret of the app client
1928
2029 Returns:
2130 bool: True if the user is authenticated, False otherwise.
2231 """
2332
24- client = boto3 .client ('cognito-idp' ) # Cognito Identity Provider client
33+ client = boto3 .client ('cognito-idp' , region_name = AWS_REGION ) # Cognito Identity Provider client
34+
35+ # Compute the secret hash
36+ secret_hash = calculate_secret_hash (client_id , client_secret , username )
2537
2638 try :
27- response = client .initiate_auth (
39+
40+ if client_secret == '' :
41+ response = client .initiate_auth (
42+ AuthFlow = 'USER_PASSWORD_AUTH' ,
43+ AuthParameters = {
44+ 'USERNAME' : username ,
45+ 'PASSWORD' : password ,
46+ },
47+ ClientId = client_id
48+ )
49+
50+ else :
51+ response = client .initiate_auth (
2852 AuthFlow = 'USER_PASSWORD_AUTH' ,
2953 AuthParameters = {
3054 'USERNAME' : username ,
3155 'PASSWORD' : password ,
56+ 'SECRET_HASH' : secret_hash
3257 },
3358 ClientId = client_id
34- )
59+ )
3560
3661 # If successful, you'll receive an AuthenticationResult in the response
3762 if response .get ('AuthenticationResult' ):
@@ -44,5 +69,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
4469 except client .exceptions .UserNotFoundException :
4570 return False
4671 except Exception as e :
47- print (f"An error occurred: { e } " )
48- return False
72+ out_message = f"An error occurred: { e } "
73+ print (out_message )
74+ raise Exception (out_message )
75+ return False
0 commit comments