FEAT new target class for AWS Bedrock Anthropic Claude models #699
Status: Open

kmarsh77 wants to merge 49 commits into Azure:main from kmarsh77:main
Commits (49)
ed1e3ff  Adding AWS Bedrock Anthropic Claude target class (kmarsh77)
63a9b2e  Adding unit tests for AWSBedrockClaudeTarget class (kmarsh77)
5de223d  Add optional aws dependency (boto3) (kmarsh77)
ac87b28  Update aws_bedrock_claude_target.py (kmarsh77)
45785d6  Adding bedrock claude target class (kmarsh77)
f7a8767  Update __init__.py for new target classes (kmarsh77)
57252d0  Unit test for AWSBedrockClaudeChatTarget (kmarsh77)
f254145  Delete pyrit/prompt_target/aws_bedrock_claude_target.py (kmarsh77)
f0bc2bc  Update __init__.py (kmarsh77)
2ebb519  Delete tests/unit/test_aws_bedrock_claude_target.py (kmarsh77)
258f287  Update aws_bedrock_claude_chat_target.py (kmarsh77)
408c308  Update test_aws_bedrock_claude_chat_target.py (kmarsh77)
5865ac6  Update pyproject.toml (kmarsh77)
01addd1  Update aws_bedrock_claude_chat_target.py (kmarsh77)
627396a  Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py (kmarsh77)
59dcb7e  Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py (kmarsh77)
b5c4924  Update aws_bedrock_claude_chat_target.py (kmarsh77)
5d8d7e0  Update aws_bedrock_claude_chat_target.py (kmarsh77)
185bcff  Update aws_bedrock_claude_chat_target.py (kmarsh77)
2630b43  Update test_aws_bedrock_claude_chat_target.py (kmarsh77)
7bcb075  Merge branch 'main' into main (kmarsh77)
6d0485d  Merge branch 'Azure:main' into main (kmarsh77)
0e0e300  Updates to address complaints from pre-commit hooks (kmarsh77)
187cb16  Merge branch 'main' into main (romanlutz)
3fd876b  Merge branch 'main' into main (romanlutz)
ef3ef17  Merge branch 'main' into main (romanlutz)
6d531d5  Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py (romanlutz)
e7c3c54  Adding exceptions for when boto3 isn't installed (kmarsh77)
8ddb596  Adding exceptions for when boto3 isn't installed (kmarsh77)
b80cbda  Merge branch 'main' of https://github.com/kmarsh77/PyRIT (kmarsh77)
b772a9c  Adding noqa statements to pass pre-commit checks (kmarsh77)
e4b10d3  Merge branch 'Azure:main' into main (kmarsh77)
d88919a  Update tests/unit/test_aws_bedrock_claude_chat_target.py (romanlutz)
fbf6a86  Merge branch 'main' into main (romanlutz)
ee1a220  Fixing merge conflict in pyproject.toml (kmarsh77)
294cdc9  changing import error message (kmarsh77)
0e84c64  Merge branch 'Azure:main' into main (kmarsh77)
0915d70  Fixed invalid converted_value_data_type in test_aws_bedrock_claude_ch… (kmarsh77)
37e92b5  Merge branch 'Azure:main' into main (kmarsh77)
c913bd4  moving test_aws_bedrock_claude_chat_target.py to tests/unit/target fo… (kmarsh77)
8bfed3d  Adding ignore statements after test_send_prompt_async and test_comple… (kmarsh77)
9d3059b  Adding ignore after boto3 use in aws_bedrock_claude_chat_target.py (kmarsh77)
b5f5df0  Removing ignore statements (kmarsh77)
5541705  removing ignore statements (kmarsh77)
e827cb4  putting boto3.client inside try statement (kmarsh77)
78af630  fixing (kmarsh77)
590167e  Moving boto3 import to within _complete_chat_async (kmarsh77)
3d83b84  Merge branch 'Azure:main' into main (kmarsh77)
b7fa256  Merge branch 'Azure:main' into main (kmarsh77)
pyrit/prompt_target/aws_bedrock_claude_chat_target.py (new file)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import json
import logging
from typing import MutableSequence, Optional

from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
from pyrit.models import (
    ChatMessageListDictContent,
    PromptRequestResponse,
    construct_response_from_request,
)
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


class AWSBedrockClaudeChatTarget(PromptChatTarget):
    """
    This class initializes an AWS Bedrock target for any of the Anthropic Claude models.
    Local AWS credentials (typically stored in ~/.aws) are used for authentication.
    See the following for more information:
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

    Parameters:
        model_id (str): The model ID of the target Claude model.
        max_tokens (int): The maximum number of tokens to generate.
        temperature (float, optional): The amount of randomness injected into the response.
        top_p (float, optional): Use nucleus sampling.
        top_k (int, optional): Only sample from the top K options for each subsequent token.
        enable_ssl_verification (bool, optional): Whether or not to perform SSL certificate verification.
    """
    def __init__(
        self,
        *,
        model_id: str,
        max_tokens: int,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        enable_ssl_verification: bool = True,
        chat_message_normalizer: ChatMessageNormalizer = ChatMessageNop(),
        max_requests_per_minute: Optional[int] = None,
    ) -> None:
        super().__init__(max_requests_per_minute=max_requests_per_minute)

        self._model_id = model_id
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._top_p = top_p
        self._top_k = top_k
        self._enable_ssl_verification = enable_ssl_verification
        self.chat_message_normalizer = chat_message_normalizer

        self._system_prompt = ""

        self._valid_image_types = ["jpeg", "png", "webp", "gif"]
    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        self._validate_request(prompt_request=prompt_request)
        request_piece = prompt_request.request_pieces[0]

        prompt_req_res_entries = self._memory.get_conversation(conversation_id=request_piece.conversation_id)
        prompt_req_res_entries.append(prompt_request)

        logger.info(f"Sending the following prompt to the prompt target: {prompt_request}")

        messages = await self._build_chat_messages(prompt_req_res_entries)

        response = await self._complete_chat_async(messages=messages)

        response_entry = construct_response_from_request(request=request_piece, response_text_pieces=[response])

        return response_entry
    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        converted_prompt_data_types = [
            request_piece.converted_value_data_type for request_piece in prompt_request.request_pieces
        ]

        for prompt_data_type in converted_prompt_data_types:
            if prompt_data_type not in ["text", "image_path"]:
                raise ValueError("This target only supports text and image_path.")
    async def _complete_chat_async(self, messages: list[ChatMessageListDictContent]) -> str:
        try:
            import boto3  # noqa: F401
        except ModuleNotFoundError as e:
            logger.error(
                "Could not import boto3. You may need to install it via 'pip install pyrit[all]' "
                "or 'pip install pyrit[aws]'."
            )
            raise e

        brt = boto3.client(
            service_name="bedrock-runtime", region_name="us-east-1", verify=self._enable_ssl_verification
        )

        native_request = self._construct_request_body(messages)

        request = json.dumps(native_request)

        try:
            response = await asyncio.to_thread(brt.invoke_model, modelId=self._model_id, body=request)
        except Exception as e:
            # botocore's ClientError is a subclass of Exception, so a single
            # except clause covers both service errors and local failures.
            raise ValueError(f"ERROR: Can't invoke '{self._model_id}'. Reason: {e}") from e

        model_response = json.loads(response["body"].read())

        answer = model_response["content"][0]["text"]

        logger.info(f'Received the following response from the prompt target "{answer}"')
        return answer
    def _convert_local_image_to_base64(self, image_path: str) -> str:
        with open(image_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read())
        return encoded_string.decode()
    async def _build_chat_messages(
        self, prompt_req_res_entries: MutableSequence[PromptRequestResponse]
    ) -> list[ChatMessageListDictContent]:
        chat_messages: list[ChatMessageListDictContent] = []
        for prompt_req_resp_entry in prompt_req_res_entries:
            prompt_request_pieces = prompt_req_resp_entry.request_pieces

            content = []
            role = None
            for prompt_request_piece in prompt_request_pieces:
                role = prompt_request_piece.role
                if role == "system":
                    # Bedrock doesn't allow a message with role==system,
                    # but it does let you specify the system prompt in a separate parameter.
                    self._system_prompt = prompt_request_piece.converted_value
                elif prompt_request_piece.converted_value_data_type == "text":
                    entry = {"type": "text", "text": prompt_request_piece.converted_value}
                    content.append(entry)
                elif prompt_request_piece.converted_value_data_type == "image_path":
                    image_type = prompt_request_piece.converted_value.split(".")[-1]
                    if image_type not in self._valid_image_types:
                        raise ValueError(
                            f"Image file {prompt_request_piece.converted_value} must "
                            "have a valid extension of .jpeg, .png, .webp, or .gif"
                        )

                    data_base64_encoded = self._convert_local_image_to_base64(prompt_request_piece.converted_value)
                    media_type = "image/" + image_type
                    entry = {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": data_base64_encoded,
                        },  # type: ignore
                    }
                    content.append(entry)
                else:
                    raise ValueError(
                        f"Multimodal data type {prompt_request_piece.converted_value_data_type} is not yet supported."
                    )

            if not role:
                raise ValueError("No role could be determined from the prompt request pieces.")

            chat_message = ChatMessageListDictContent(role=role, content=content)
            chat_messages.append(chat_message)
        return chat_messages
    def _construct_request_body(self, messages_list: list[ChatMessageListDictContent]) -> dict:
        content = []

        for message in messages_list:
            if message.role != "system":
                entry = {"role": message.role, "content": message.content}
                content.append(entry)

        data = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self._max_tokens,
            "system": self._system_prompt,
            "messages": content,
        }

        # Compare against None so that legitimate zero values
        # (e.g. temperature=0.0) are still included in the request.
        if self._temperature is not None:
            data["temperature"] = self._temperature
        if self._top_p is not None:
            data["top_p"] = self._top_p
        if self._top_k is not None:
            data["top_k"] = self._top_k

        return data

    def is_json_response_supported(self) -> bool:
        """Indicates that this target does not support a JSON response format."""
        return False
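The body assembled by `_construct_request_body` follows the Bedrock Anthropic Messages API shape (`anthropic_version`, `max_tokens`, `system`, `messages`, plus optional sampling parameters). As a minimal, dependency-free sketch of that construction (the function name and defaults here are illustrative, not part of PyRIT):

```python
import json
from typing import Optional


def build_claude_request_body(
    messages: list,
    *,
    max_tokens: int,
    system_prompt: str = "",
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
) -> dict:
    """Assemble a Bedrock Anthropic Messages API request body.

    System messages are carried in the separate "system" field, and
    sampling parameters are compared against None so that a legitimate
    temperature of 0.0 is still sent.
    """
    body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "system": system_prompt,
        "messages": [m for m in messages if m.get("role") != "system"],
    }
    if temperature is not None:
        body["temperature"] = temperature
    if top_p is not None:
        body["top_p"] = top_p
    if top_k is not None:
        body["top_k"] = top_k
    return body


body = build_claude_request_body(
    [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
    max_tokens=100,
    temperature=0.0,
)
print(json.dumps(body, indent=2))
```

The `is not None` comparison mirrors a subtle point in the target class: a plain truthiness check would silently drop `temperature=0.0` from the request.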
tests/unit/target/test_aws_bedrock_claude_chat_target.py (new file)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
from unittest.mock import MagicMock, patch

import pytest

from pyrit.models import (
    ChatMessageListDictContent,
    PromptRequestPiece,
    PromptRequestResponse,
)
from pyrit.prompt_target.aws_bedrock_claude_chat_target import (
    AWSBedrockClaudeChatTarget,
)


def is_boto3_installed():
    try:
        import boto3  # noqa: F401

        return True
    except ModuleNotFoundError:
        return False


@pytest.fixture
def aws_target() -> AWSBedrockClaudeChatTarget:
    return AWSBedrockClaudeChatTarget(
        model_id="anthropic.claude-v2",
        max_tokens=100,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        enable_ssl_verification=True,
    )


@pytest.fixture
def mock_prompt_request():
    request_piece = PromptRequestPiece(
        role="user", original_value="Hello, Claude!", converted_value="Hello, how are you?"
    )
    return PromptRequestResponse(request_pieces=[request_piece])
@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
@pytest.mark.asyncio
async def test_send_prompt_async(aws_target, mock_prompt_request):
    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
        mock_client = mock_boto.return_value
        mock_client.invoke_model.return_value = {
            "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "I'm good, thanks!"}]})))
        }

        response = await aws_target.send_prompt_async(prompt_request=mock_prompt_request)

        assert response.request_pieces[0].converted_value == "I'm good, thanks!"


@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
@pytest.mark.asyncio
async def test_validate_request_valid(aws_target, mock_prompt_request):
    aws_target._validate_request(prompt_request=mock_prompt_request)


@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
@pytest.mark.asyncio
async def test_validate_request_invalid_data_type(aws_target):
    request_pieces = [
        PromptRequestPiece(
            role="user", original_value="test", converted_value="ImageData", converted_value_data_type="video_path"
        )
    ]
    invalid_request = PromptRequestResponse(request_pieces=request_pieces)

    with pytest.raises(ValueError, match="This target only supports text and image_path."):
        aws_target._validate_request(prompt_request=invalid_request)


@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
@pytest.mark.asyncio
async def test_complete_chat_async(aws_target):
    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
        mock_client = mock_boto.return_value
        mock_client.invoke_model.return_value = {
            "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "Test Response"}]})))
        }

        response = await aws_target._complete_chat_async(
            messages=[ChatMessageListDictContent(role="user", content=[{"type": "text", "text": "Test input"}])]
        )

        assert response == "Test Response"
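The tests above stub `invoke_model` so that `response["body"].read()` returns a serialized payload. A standalone sketch of that parsing and mocking pattern, with no boto3 required (the response shape is assumed from the Bedrock runtime API, and `parse_bedrock_response` is an illustrative helper, not part of PyRIT):

```python
import json
from unittest.mock import MagicMock


def parse_bedrock_response(response: dict) -> str:
    """Extract the first text block from an invoke_model response.

    Bedrock returns a streaming body; read() yields the serialized
    Anthropic Messages API payload.
    """
    model_response = json.loads(response["body"].read())
    return model_response["content"][0]["text"]


# Build a fake response matching the shape the tests above stub out.
fake_response = {
    "body": MagicMock(
        read=MagicMock(return_value=json.dumps({"content": [{"text": "I'm good, thanks!"}]}))
    )
}

answer = parse_bedrock_response(fake_response)
print(answer)  # prints "I'm good, thanks!"
```

Patching at `boto3.client` works here even though the target imports boto3 lazily inside `_complete_chat_async`, because the patched attribute is looked up on the boto3 module at call time.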