Skip to content

Commit a1ae0af

Browse files
Merge pull request #1 from seanpedrick-case/dev
Added Gemini and AWS Bedrock compatibility. Gemma model. Now document redaction QA.
2 parents 0c818aa + 03afd76 commit a1ae0af

File tree

14 files changed

+1086
-546
lines changed

14 files changed

+1086
-546
lines changed

.dockerignore

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
*.pyc
2+
*.ipynb
3+
*.pdf
4+
*.spec
5+
*.toc
6+
*.csv
7+
*.bin
8+
bootstrapper.py
9+
build/*
10+
dist/*
11+
test/*
12+
config/*
13+
output/*
14+
input/*

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.zip filter=lfs diff=lfs merge=lfs -text

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@
88
bootstrapper.py
99
build/*
1010
dist/*
11-
test/*
11+
test/*
12+
config/*
13+
output/*
14+
input/*

app.py

Lines changed: 172 additions & 165 deletions
Large diffs are not rendered by default.

chatfuncs/auth.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,62 @@
1-
1+
#import os
22
import boto3
3-
from chatfuncs.helper_functions import get_or_create_env_var
4-
5-
client_id = get_or_create_env_var('AWS_CLIENT_ID', '') # This client id is borrowed from async gradio app client
6-
print(f'The value of AWS_CLIENT_ID is {client_id}')
3+
#import gradio as gr
4+
import hmac
5+
import hashlib
6+
import base64
7+
from chatfuncs.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION
78

8-
user_pool_id = get_or_create_env_var('AWS_USER_POOL_ID', '')
9-
print(f'The value of AWS_USER_POOL_ID is {user_pool_id}')
9+
def calculate_secret_hash(client_id:str, client_secret:str, username:str):
10+
message = username + client_id
11+
dig = hmac.new(
12+
str(client_secret).encode('utf-8'),
13+
msg=str(message).encode('utf-8'),
14+
digestmod=hashlib.sha256
15+
).digest()
16+
secret_hash = base64.b64encode(dig).decode()
17+
return secret_hash
1018

11-
def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=client_id):
19+
def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
1220
"""Authenticates a user against an AWS Cognito user pool.
1321
1422
Args:
1523
user_pool_id (str): The ID of the Cognito user pool.
1624
client_id (str): The ID of the Cognito user pool client.
1725
username (str): The username of the user.
1826
password (str): The password of the user.
27+
client_secret (str): The client secret of the app client
1928
2029
Returns:
2130
bool: True if the user is authenticated, False otherwise.
2231
"""
2332

24-
client = boto3.client('cognito-idp') # Cognito Identity Provider client
33+
client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client
34+
35+
# Compute the secret hash
36+
secret_hash = calculate_secret_hash(client_id, client_secret, username)
2537

2638
try:
27-
response = client.initiate_auth(
39+
40+
if client_secret == '':
41+
response = client.initiate_auth(
42+
AuthFlow='USER_PASSWORD_AUTH',
43+
AuthParameters={
44+
'USERNAME': username,
45+
'PASSWORD': password,
46+
},
47+
ClientId=client_id
48+
)
49+
50+
else:
51+
response = client.initiate_auth(
2852
AuthFlow='USER_PASSWORD_AUTH',
2953
AuthParameters={
3054
'USERNAME': username,
3155
'PASSWORD': password,
56+
'SECRET_HASH': secret_hash
3257
},
3358
ClientId=client_id
34-
)
59+
)
3560

3661
# If successful, you'll receive an AuthenticationResult in the response
3762
if response.get('AuthenticationResult'):
@@ -44,5 +69,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
4469
except client.exceptions.UserNotFoundException:
4570
return False
4671
except Exception as e:
47-
print(f"An error occurred: {e}")
48-
return False
72+
out_message = f"An error occurred: {e}"
73+
print(out_message)
74+
raise Exception(out_message)
75+
return False

0 commit comments

Comments
 (0)