Skip to content

Commit 76a3614

Browse files
committed
fix: properly handle tool_use messages in conversation
1 parent 0183608 commit 76a3614

File tree

5 files changed

+103
-8
lines changed

5 files changed

+103
-8
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,5 @@ cython_debug/
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
.idea/
161161

162-
Config
162+
Config
163+
.vscode/launch.json

src/api/app.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ async def health():
4545

4646
@app.exception_handler(RequestValidationError)
4747
async def validation_exception_handler(request, exc):
48+
logger = logging.getLogger(__name__)
49+
50+
# Log essential info only - avoid sensitive data and performance overhead
51+
logger.warning(
52+
"Request validation failed: %s %s - %s",
53+
request.method,
54+
request.url.path,
55+
str(exc).split('\n')[0] # First line only
56+
)
57+
4858
return PlainTextResponse(str(exc), status_code=400)
4959

5060

src/api/models/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import time
23
import uuid
34
from abc import ABC, abstractmethod
@@ -14,6 +15,8 @@
1415
Error,
1516
)
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
class BaseChatModel(ABC):
1922
"""Represent a basic chat model
@@ -46,6 +49,7 @@ def generate_message_id() -> str:
4649
@staticmethod
4750
def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes:
4851
if isinstance(response, Error):
52+
logger.error("Stream error: %s", response.error.message if response.error else "Unknown error")
4953
data = response.model_dump_json()
5054
elif isinstance(response, ChatStreamResponse):
5155
# to populate other fields when using exclude_unset=True

src/api/models/bedrock.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ResponseFunction,
3535
TextContent,
3636
ToolCall,
37+
ToolContent,
3738
ToolMessage,
3839
Usage,
3940
UserMessage,
@@ -48,7 +49,15 @@
4849

4950
logger = logging.getLogger(__name__)
5051

51-
config = Config(connect_timeout=60, read_timeout=120, retries={"max_attempts": 1})
52+
config = Config(
53+
connect_timeout=60, # Connection timeout: 60 seconds
54+
read_timeout=900, # Read timeout: 15 minutes (suitable for long streaming responses)
55+
retries={
56+
'max_attempts': 8, # Maximum retry attempts
57+
'mode': 'adaptive' # Adaptive retry mode
58+
},
59+
max_pool_connections=50 # Maximum connection pool size
60+
)
5261

5362
bedrock_runtime = boto3.client(
5463
service_name="bedrock-runtime",
@@ -177,6 +186,7 @@ def validate(self, chat_request: ChatRequest):
177186
# check if model is supported
178187
if chat_request.model not in bedrock_model_list.keys():
179188
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
189+
logger.error("Unsupported model: %s", chat_request.model)
180190

181191
if error:
182192
raise HTTPException(
@@ -204,13 +214,13 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
204214
# Run the blocking boto3 call in a thread pool
205215
response = await run_in_threadpool(bedrock_runtime.converse, **args)
206216
except bedrock_runtime.exceptions.ValidationException as e:
207-
logger.error("Validation Error: " + str(e))
217+
logger.error("Bedrock validation error for model %s: %s", chat_request.model, str(e))
208218
raise HTTPException(status_code=400, detail=str(e))
209219
except bedrock_runtime.exceptions.ThrottlingException as e:
210-
logger.error("Throttling Error: " + str(e))
220+
logger.warning("Bedrock throttling for model %s: %s", chat_request.model, str(e))
211221
raise HTTPException(status_code=429, detail=str(e))
212222
except Exception as e:
213-
logger.error(e)
223+
logger.error("Bedrock invocation failed for model %s: %s", chat_request.model, str(e))
214224
raise HTTPException(status_code=500, detail=str(e))
215225
return response
216226

@@ -270,6 +280,7 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
270280
# return an [DONE] message at the end.
271281
yield self.stream_response_to_bytes()
272282
except Exception as e:
283+
logger.error("Stream error for model %s: %s", chat_request.model, str(e))
273284
error_event = Error(error=ErrorMessage(message=str(e)))
274285
yield self.stream_response_to_bytes(error_event)
275286

@@ -317,7 +328,16 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
317328
}
318329
)
319330
elif isinstance(message, AssistantMessage):
320-
if message.content.strip():
331+
# Check if message has content that's not empty
332+
has_content = False
333+
if isinstance(message.content, str):
334+
has_content = message.content.strip() != ""
335+
elif isinstance(message.content, list):
336+
has_content = len(message.content) > 0
337+
elif message.content is not None:
338+
has_content = True
339+
340+
if has_content:
321341
# Text message
322342
messages.append(
323343
{
@@ -349,14 +369,18 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
349369
# Bedrock does not support tool role,
350370
# Add toolResult to content
351371
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
372+
373+
# Handle different content formats from OpenAI SDK
374+
tool_content = self._extract_tool_content(message.content)
375+
352376
messages.append(
353377
{
354378
"role": "user",
355379
"content": [
356380
{
357381
"toolResult": {
358382
"toolUseId": message.tool_call_id,
359-
"content": [{"text": message.content}],
383+
"content": [{"text": tool_content}],
360384
}
361385
}
362386
],
@@ -368,6 +392,57 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
368392
continue
369393
return self._reframe_multi_payloard(messages)
370394

395+
def _extract_tool_content(self, content) -> str:
396+
"""Extract text content from various OpenAI SDK tool message formats.
397+
398+
Handles:
399+
- String content (legacy format)
400+
- List of content objects (OpenAI SDK 1.91.0+)
401+
- Nested JSON structures within text content
402+
"""
403+
try:
404+
if isinstance(content, str):
405+
return content
406+
407+
if isinstance(content, list):
408+
text_parts = []
409+
for i, item in enumerate(content):
410+
if isinstance(item, dict):
411+
# Handle dict with 'text' field
412+
if "text" in item:
413+
item_text = item["text"]
414+
if isinstance(item_text, str):
415+
# Try to parse as JSON if it looks like JSON
416+
if item_text.strip().startswith('{') and item_text.strip().endswith('}'):
417+
try:
418+
parsed_json = json.loads(item_text)
419+
# Convert JSON object to readable text
420+
text_parts.append(json.dumps(parsed_json, indent=2))
421+
except json.JSONDecodeError:
422+
# Silently fallback to original text
423+
text_parts.append(item_text)
424+
else:
425+
text_parts.append(item_text)
426+
else:
427+
text_parts.append(str(item_text))
428+
else:
429+
# Handle other dict formats - convert to JSON string
430+
text_parts.append(json.dumps(item, indent=2))
431+
elif hasattr(item, 'text'):
432+
# Handle ToolContent objects
433+
text_parts.append(item.text)
434+
else:
435+
# Convert any other type to string
436+
text_parts.append(str(item))
437+
return "\n".join(text_parts)
438+
439+
# Fallback for any other type
440+
return str(content)
441+
except Exception as e:
442+
logger.warning("Tool content extraction failed: %s", str(e))
443+
# Return a safe fallback
444+
return str(content) if content is not None else ""
445+
371446
def _reframe_multi_payloard(self, messages: list) -> list:
372447
"""Receive messages and reformat them to comply with the Claude format
373448

src/api/schema.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ class ImageContent(BaseModel):
4545
image_url: ImageUrl
4646

4747

48+
class ToolContent(BaseModel):
49+
type: Literal["text"] = "text"
50+
text: str
51+
52+
4853
class SystemMessage(BaseModel):
4954
name: str | None = None
5055
role: Literal["system"] = "system"
@@ -66,7 +71,7 @@ class AssistantMessage(BaseModel):
6671

6772
class ToolMessage(BaseModel):
6873
role: Literal["tool"] = "tool"
69-
content: str
74+
content: str | list[ToolContent] | list[dict]
7075
tool_call_id: str
7176

7277

0 commit comments

Comments
 (0)